[Core] move transformer scripts to transformers modules (#6747)
* move transformer scripts to transformers modules
* move transformer model test
* move prior transformer test to directory
* fix doc path
* correct doc path
* add: __init__.py
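For downstream code the change is purely a matter of import paths: the classes themselves are untouched, the module path just gains a `transformers.` segment, and the old modules are kept as deprecation shims (see the per-file diffs below). A minimal sketch of the before/after, based on the import lines changed in this commit:

    # New, preferred import path introduced by this commit
    from diffusers.models.transformers.prior_transformer import PriorTransformer

    # The package-level export keeps working as before (same class object)
    from diffusers.models import PriorTransformer

    # The old module path also still resolves, but only through a deprecation shim
    # scheduled for removal in 0.29:
    #   from diffusers.models.prior_transformer import PriorTransformer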
@@ -24,4 +24,4 @@ The abstract from the paper is:

## PriorTransformerOutput

[[autodoc]] models.prior_transformer.PriorTransformerOutput
[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted

## Transformer2DModelOutput

[[autodoc]] models.transformer_2d.Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
@@ -16,8 +16,8 @@ A Transformer model for video-like data.

## TransformerTemporalModel

[[autodoc]] models.transformer_temporal.TransformerTemporalModel
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel

## TransformerTemporalModelOutput

[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput
@@ -6,7 +6,7 @@ from accelerate import load_checkpoint_and_dispatch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
@@ -6,7 +6,7 @@ import torch
from accelerate import load_checkpoint_and_dispatch

from diffusers import UNet2DConditionModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.models.vq_model import VQModel
@@ -4,7 +4,7 @@ import tempfile
import torch
from accelerate import load_checkpoint_and_dispatch

from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapERenderer
@@ -35,10 +35,10 @@ if is_torch_available():
    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
    _import_structure["embeddings"] = ["ImageProjection"]
    _import_structure["modeling_utils"] = ["ModelMixin"]
    _import_structure["prior_transformer"] = ["PriorTransformer"]
    _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["unets.unet_1d"] = ["UNet1DModel"]
    _import_structure["unets.unet_2d"] = ["UNet2DModel"]
    _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -66,13 +66,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        ConsistencyDecoderVAE,
    )
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .embeddings import ImageProjection
    from .modeling_utils import ModelMixin
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .transformer_temporal import TransformerTemporalModel
    from .transformers import (
        DualTransformer2DModel,
        PriorTransformer,
        T5FilmDecoder,
        Transformer2DModel,
        TransformerTemporalModel,
    )
    from .unets import (
        Kandinsky3UNet,
        MotionAdapter,
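The pairing above (string keys in `_import_structure` when torch is available, real imports only under `TYPE_CHECKING` / `DIFFUSERS_SLOW_IMPORT`) is the lazy-import pattern `diffusers` uses so that `diffusers.models` does not eagerly import every submodule; the actual machinery lives in a `_LazyModule` helper in `diffusers.utils`. A rough, self-contained sketch of the idea only, with a made-up two-entry mapping (not the real implementation):

    import importlib
    from typing import Dict, List

    # Hypothetical, stripped-down mapping in the spirit of diffusers.models.__init__:
    # dotted submodule path -> public names it defines.
    _import_structure: Dict[str, List[str]] = {
        "transformers.prior_transformer": ["PriorTransformer"],
        "transformers.transformer_2d": ["Transformer2DModel"],
    }

    # Reverse lookup: public name -> submodule that defines it.
    _name_to_module = {name: module for module, names in _import_structure.items() for name in names}


    def __getattr__(name: str):
        # PEP 562 module-level __getattr__: import the defining submodule only on first access.
        module_path = _name_to_module.get(name)
        if module_path is None:
            raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
        submodule = importlib.import_module(f".{module_path}", __name__)
        return getattr(submodule, name)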
@@ -11,145 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn

from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
from ..utils import deprecate
from .transformers.dual_transformer_2d import DualTransformer2DModel


class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Optional attention mask to be applied in Attention.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)


class DualTransformer2DModel(DualTransformer2DModel):
    deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
    deprecate("DualTransformer2DModel", "0.29", deprecation_message)
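The two class-body lines at the end are the whole backward-compatibility story for this file: the legacy module subclasses the relocated implementation and calls `deprecate()` once when the module is imported. A stripped-down sketch of the same pattern with hypothetical module names, using the stdlib `warnings` module in place of `diffusers.utils.deprecate`:

    # old_package/widget.py -- hypothetical legacy location kept only as a shim
    import warnings

    from new_package.widget import Widget as _RelocatedWidget  # hypothetical new home


    class Widget(_RelocatedWidget):
        """Deprecated alias; kept so `from old_package.widget import Widget` keeps working."""


    warnings.warn(
        "Importing `Widget` from `old_package.widget` is deprecated; "
        "use `from new_package.widget import Widget` instead.",
        FutureWarning,
        stacklevel=2,
    )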
@@ -1,380 +1,12 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
from ..utils import deprecate
from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin


class PriorTransformerOutput(PriorTransformerOutput):
    deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead."
    deprecate("PriorTransformerOutput", "0.29", deprecation_message)


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    """
    A Prior Transformer model.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Set it to `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to None):
            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
            needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
            product between the text embedding and image embedding as proposed in the unclip paper
            https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
        time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
            If None, will be set to `num_attention_heads * attention_head_dim`
        embedding_proj_dim (`int`, *optional*, default to None):
            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, default to None):
            The dimension of the output. If None, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        if added_emb_type == "prd":
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        if self.prd_embedding is not None:
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents


class PriorTransformer(PriorTransformer):
    deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead."
    deprecate("PriorTransformer", "0.29", deprecation_message)
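The forward signature documented above maps directly onto a toy invocation. A minimal smoke test with deliberately tiny, non-default hyperparameters (the values are assumptions chosen only so the example is cheap to run; real checkpoints use the documented defaults):

    import torch

    from diffusers.models.transformers.prior_transformer import PriorTransformer

    # Deliberately tiny configuration; inner_dim = 2 * 4 = 8 and clip_embed_dim falls back to embedding_dim.
    model = PriorTransformer(
        num_attention_heads=2,
        attention_head_dim=4,
        num_layers=2,
        embedding_dim=8,
        num_embeddings=3,
        additional_embeddings=4,
    )

    batch_size = 1
    hidden_states = torch.randn(batch_size, 8)             # current image-embedding estimate
    timestep = torch.tensor([10])                           # current denoising step
    proj_embedding = torch.randn(batch_size, 8)             # conditioning embedding
    encoder_hidden_states = torch.randn(batch_size, 3, 8)   # text hidden states (num_embeddings tokens)

    out = model(
        hidden_states,
        timestep=timestep,
        proj_embedding=proj_embedding,
        encoder_hidden_states=encoder_hidden_states,
    )
    print(out.predicted_image_embedding.shape)  # torch.Size([1, 8])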
@@ -11,428 +11,60 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

from ..utils import deprecate
from .transformers.t5_film_transformer import (
    DecoderLayer,
    NewGELUActivation,
    T5DenseGatedActDense,
    T5FilmDecoder,
    T5FiLMLayer,
    T5LayerCrossAttention,
    T5LayerFFCond,
    T5LayerNorm,
    T5LayerSelfAttentionCond,
)

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from .attention_processor import Attention
from .embeddings import get_timestep_embedding
from .modeling_utils import ModelMixin


class T5FilmDecoder(T5FilmDecoder):
    deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead."
    deprecate("T5FilmDecoder", "0.29", deprecation_message)


class T5FilmDecoder(ModelMixin, ConfigMixin):
    r"""
    T5 style decoder with FiLM conditioning.


class DecoderLayer(DecoderLayer):
    deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead."
    deprecate("DecoderLayer", "0.29", deprecation_message)


    Args:
        input_dims (`int`, *optional*, defaults to `128`):
            The number of input dimensions.
        targets_length (`int`, *optional*, defaults to `256`):
            The length of the targets.
        d_model (`int`, *optional*, defaults to `768`):
            Size of the input hidden states.
        num_layers (`int`, *optional*, defaults to `12`):
            The number of `DecoderLayer`'s to use.
        num_heads (`int`, *optional*, defaults to `12`):
            The number of attention heads to use.
        d_kv (`int`, *optional*, defaults to `64`):
            Size of the key-value projection vectors.
        d_ff (`int`, *optional*, defaults to `2048`):
            The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
        dropout_rate (`float`, *optional*, defaults to `0.1`):
            Dropout probability.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()


class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond):
    deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead."
    deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message)


        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias=False),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model * 4, bias=False),
            nn.SiLU(),
        )

        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False


class T5LayerCrossAttention(T5LayerCrossAttention):
    deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead."
    deprecate("T5LayerCrossAttention", "0.29", deprecation_message)


        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)


class T5LayerFFCond(T5LayerFFCond):
    deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead."
    deprecate("T5LayerFFCond", "0.29", deprecation_message)


        self.decoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            # FiLM conditional T5 decoder
            lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            self.decoders.append(lyr)

        self.decoder_norm = T5LayerNorm(d_model)


class T5DenseGatedActDense(T5DenseGatedActDense):
    deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead."
    deprecate("T5DenseGatedActDense", "0.29", deprecation_message)


        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
        mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return mask.unsqueeze(-3)


class T5LayerNorm(T5LayerNorm):
    deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead."
    deprecate("T5LayerNorm", "0.29", deprecation_message)


    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        batch, _, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        time_steps = get_timestep_embedding(
            decoder_noise_time * self.config.max_decoder_noise_time,
            embedding_dim=self.config.d_model,
            max_period=self.config.max_decoder_noise_time,
        ).to(dtype=self.dtype)


class NewGELUActivation(NewGELUActivation):
    deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead."
    deprecate("NewGELUActivation", "0.29", deprecation_message)


        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        seq_length = decoder_input_tokens.shape[1]

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        decoder_positions = torch.broadcast_to(
            torch.arange(seq_length, device=decoder_input_tokens.device),
            (batch, seq_length),
        )

        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones(
            decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
        )

        # Translate encoding masks to encoder-decoder masks.
        encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]

        # cross attend style: concat encodings
        encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
        encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)

        for lyr in self.decoders:
            y = lyr(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )[0]

        y = self.decoder_norm(y)
        y = self.post_dropout(y)

        spec_out = self.spec_out(y)
        return spec_out


class DecoderLayer(nn.Module):
    r"""
    T5 decoder layer.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_kv (`int`):
            Size of the key-value projection vectors.
        num_heads (`int`):
            Number of attention heads.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(
        self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
    ):
        super().__init__()
        self.layer = nn.ModuleList()

        # cond self attention: layer 0
        self.layer.append(
            T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
        )

        # cross attention: layer 1
        self.layer.append(
            T5LayerCrossAttention(
                d_model=d_model,
                d_kv=d_kv,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                layer_norm_epsilon=layer_norm_epsilon,
            )
        )

        # Film Cond MLP + dropout: last layer
        self.layer.append(
            T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        conditioning_emb: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_decoder_position_bias=None,
    ) -> Tuple[torch.FloatTensor]:
        hidden_states = self.layer[0](
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
                encoder_hidden_states.dtype
            )

            hidden_states = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_extended_attention_mask,
            )

        # Apply Film Conditional Feed Forward layer
        hidden_states = self.layer[-1](hidden_states, conditioning_emb)

        return (hidden_states,)


class T5LayerSelfAttentionCond(nn.Module):
    r"""
    T5 style self-attention layer with conditioning.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_kv (`int`):
            Size of the key-value projection vectors.
        num_heads (`int`):
            Number of attention heads.
        dropout_rate (`float`):
            Dropout probability.
    """

    def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
        super().__init__()
        self.layer_norm = T5LayerNorm(d_model)
        self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        conditioning_emb: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # pre_self_attention_layer_norm
        normed_hidden_states = self.layer_norm(hidden_states)

        if conditioning_emb is not None:
            normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)

        # Self-attention block
        attention_output = self.attention(normed_hidden_states)

        hidden_states = hidden_states + self.dropout(attention_output)

        return hidden_states


class T5LayerCrossAttention(nn.Module):
    r"""
    T5 style cross-attention layer.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_kv (`int`):
            Size of the key-value projection vectors.
        num_heads (`int`):
            Number of attention heads.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        key_value_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask.squeeze(1),
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output


class T5LayerFFCond(nn.Module):
    r"""
    T5 style feed-forward conditional layer.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        forwarded_states = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            forwarded_states = self.film(forwarded_states, conditioning_emb)

        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    r"""
    T5 style feed-forward layer with gated activations and dropout.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerNorm(nn.Module):
    r"""
    T5 style layer normalization module.

    Args:
        hidden_size (`int`):
            Size of the input hidden states.
        eps (`float`, `optional`, defaults to `1e-6`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


class T5FiLMLayer(nn.Module):
    """
    T5 style FiLM Layer.

    Args:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
        emb = self.scale_bias(conditioning_emb)
        scale, shift = torch.chunk(emb, 2, -1)
        x = x * (1 + scale) + shift
        return x


class T5FiLMLayer(T5FiLMLayer):
    deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead."
    deprecate("T5FiLMLayer", "0.29", deprecation_message)
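The FiLM layer defined above is the conditioning primitive the whole decoder leans on: a single linear projection of the conditioning embedding is split into a per-feature scale and shift. A tiny standalone check (shapes are arbitrary; the `4 * d_model` width mirrors how `T5FilmDecoder` builds its conditioning embedding):

    import torch

    from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer

    d_model = 8
    film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)

    x = torch.randn(2, 5, d_model)                      # a sequence of hidden states
    conditioning_emb = torch.randn(2, 1, d_model * 4)   # one conditioning vector per sample

    y = film(x, conditioning_emb)                       # computes x * (1 + scale) + shift
    print(y.shape)                                      # torch.Size([2, 5, 8])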
@@ -11,449 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
from ..utils import deprecate
from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModelOutput(Transformer2DModelOutput):
    deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead."
    deprecate("Transformer2DModelOutput", "0.29", deprecation_message)


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: int = None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = linear_cls(in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
            interpolation_scale = max(interpolation_scale, 1)
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
                interpolation_scale=interpolation_scale,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = linear_cls(inner_dim, in_channels)
            else:
                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches and norm_type != "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
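The keep/discard-to-bias conversion described in the comments above can be hard to picture in isolation; this standalone sketch (with made-up values, not taken from this diff) restates it:

import torch

mask = torch.tensor([[1, 1, 0]])              # 1 = keep, 0 = discard
bias = (1 - mask.float()) * -10000.0          # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                      # (batch, 1, key_tokens), broadcastable over query tokens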
|
||||
class Transformer2DModel(Transformer2DModel):
|
||||
deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead."
|
||||
deprecate("Transformer2DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,369 +11,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .resnet import AlphaBlender
|
||||
from ..utils import deprecate
|
||||
from .transformers.transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
TransformerTemporalModelOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
|
||||
deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporalModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
class TransformerTemporalModel(TransformerTemporalModel):
|
||||
deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModel", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input hidden_states.
|
||||
num_frames (`int`):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
|
||||
deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
|
||||
src/diffusers/models/transformers/__init__.py
@@ -0,0 +1,9 @@
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
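With this `__init__.py` in place, the moved classes can be imported from the new sub-package, while the deprecated shims above keep the old module paths importable until the removal version noted in their messages:

from diffusers.models.transformers import (
    DualTransformer2DModel,
    PriorTransformer,
    T5FilmDecoder,
    Transformer2DModel,
    TransformerTemporalModel,
)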
|
||||
src/diffusers/models/transformers/dual_transformer_2d.py
@@ -0,0 +1,155 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
|
||||
|
||||
class DualTransformer2DModel(nn.Module):
|
||||
"""
|
||||
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to but not more than `num_embeds_ada_norm` steps.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformers = nn.ModuleList(
|
||||
[
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
sample_size=sample_size,
|
||||
num_vector_embeds=num_vector_embeds,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
input_states = hidden_states
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
if not return_dict:
|
||||
return (output_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output_states)
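To illustrate how a pipeline is expected to drive the mixing attributes set in `__init__`, here is a hedged, toy-sized sketch (none of these values come from a real checkpoint):

import torch
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel

dual = DualTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,
    cross_attention_dim=64,
)
dual.mix_ratio = 0.5
dual.condition_lengths = [2, 3]                  # token counts of the two conditions
dual.transformer_index_for_condition = [1, 0]

sample = torch.randn(1, 32, 8, 8)
conditions = torch.randn(1, 2 + 3, 64)           # both conditions concatenated along dim 1
out = dual(sample, conditions, return_dict=False)[0]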
|
||||
src/diffusers/models/transformers/prior_transformer.py
@@ -0,0 +1,380 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`PriorTransformer`].
|
||||
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
||||
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The currently predicted image embeddings.
|
||||
timestep (`torch.LongTensor`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
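# Illustrative sketch (not part of this diff): the ONNX/Core ML-friendly timestep broadcast used in
# `PriorTransformer.forward` above. Multiplying a 1-element timestep tensor by a ones tensor yields one
# timestep per batch item without relying on `Tensor.expand`, which some export backends handle poorly.
# The batch size and timestep value below are made up for illustration.
import torch

batch_size = 4
timestep = torch.tensor(10, dtype=torch.long)
timesteps = timestep[None]  # shape (1,)
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype)
assert timesteps.shape == (batch_size,)
assert timesteps.tolist() == [10, 10, 10, 10]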
src/diffusers/models/transformers/t5_film_transformer.py (new file, 438 lines)
@@ -0,0 +1,438 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
T5 style decoder with FiLM conditioning.
|
||||
|
||||
Args:
|
||||
input_dims (`int`, *optional*, defaults to `128`):
|
||||
The number of input dimensions.
|
||||
targets_length (`int`, *optional*, defaults to `256`):
|
||||
The length of the targets.
|
||||
d_model (`int`, *optional*, defaults to `768`):
|
||||
Size of the input hidden states.
|
||||
num_layers (`int`, *optional*, defaults to `12`):
|
||||
The number of `DecoderLayer`'s to use.
|
||||
num_heads (`int`, *optional*, defaults to `12`):
|
||||
The number of attention heads to use.
|
||||
d_kv (`int`, *optional*, defaults to `64`):
|
||||
Size of the key-value projection vectors.
|
||||
d_ff (`int`, *optional*, defaults to `2048`):
|
||||
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
|
||||
dropout_rate (`float`, *optional*, defaults to `0.1`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int = 128,
|
||||
targets_length: int = 256,
|
||||
max_decoder_noise_time: float = 2000.0,
|
||||
d_model: int = 768,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 2048,
|
||||
dropout_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conditioning_emb = nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_encoding = nn.Embedding(targets_length, d_model)
|
||||
self.position_encoding.weight.requires_grad = False
|
||||
|
||||
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.decoders = nn.ModuleList()
|
||||
for lyr_num in range(num_layers):
|
||||
# FiLM conditional T5 decoder
|
||||
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.decoders.append(lyr)
|
||||
|
||||
self.decoder_norm = T5LayerNorm(d_model)
|
||||
|
||||
self.post_dropout = nn.Dropout(p=dropout_rate)
|
||||
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
||||
|
||||
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
||||
return mask.unsqueeze(-3)
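# Illustrative sketch (not part of this diff): what `encoder_decoder_mask` above computes. For a decoder
# query mask of shape (batch, query_len) and an encoder key mask of shape (batch, key_len), the outer
# product of the two gives a (batch, 1, query_len, key_len) mask that is nonzero only where both the
# query and the key position are valid. The sizes below are arbitrary.
import torch

query_mask = torch.ones(2, 5)   # (batch, query_len)
key_mask = torch.ones(2, 7)     # (batch, key_len)
mask = torch.mul(query_mask.unsqueeze(-1), key_mask.unsqueeze(-2))  # (2, 5, 7)
mask = mask.unsqueeze(-3)       # (2, 1, 5, 7), broadcastable over attention heads
assert mask.shape == (2, 1, 5, 7)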
|
||||
|
||||
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
||||
batch, _, _ = decoder_input_tokens.shape
|
||||
assert decoder_noise_time.shape == (batch,)
|
||||
|
||||
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
||||
time_steps = get_timestep_embedding(
|
||||
decoder_noise_time * self.config.max_decoder_noise_time,
|
||||
embedding_dim=self.config.d_model,
|
||||
max_period=self.config.max_decoder_noise_time,
|
||||
).to(dtype=self.dtype)
|
||||
|
||||
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
|
||||
|
||||
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
||||
|
||||
seq_length = decoder_input_tokens.shape[1]
|
||||
|
||||
# If we want to use relative positions for audio context, we can just offset
|
||||
# this sequence by the length of encodings_and_masks.
|
||||
decoder_positions = torch.broadcast_to(
|
||||
torch.arange(seq_length, device=decoder_input_tokens.device),
|
||||
(batch, seq_length),
|
||||
)
|
||||
|
||||
position_encodings = self.position_encoding(decoder_positions)
|
||||
|
||||
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
||||
inputs += position_encodings
|
||||
y = self.dropout(inputs)
|
||||
|
||||
# decoder: No padding present.
|
||||
decoder_mask = torch.ones(
|
||||
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
||||
)
|
||||
|
||||
# Translate encoding masks to encoder-decoder masks.
|
||||
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
||||
|
||||
# cross attend style: concat encodings
|
||||
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
||||
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
|
||||
|
||||
for lyr in self.decoders:
|
||||
y = lyr(
|
||||
y,
|
||||
conditioning_emb=conditioning_emb,
|
||||
encoder_hidden_states=encoded,
|
||||
encoder_attention_mask=encoder_decoder_mask,
|
||||
)[0]
|
||||
|
||||
y = self.decoder_norm(y)
|
||||
y = self.post_dropout(y)
|
||||
|
||||
spec_out = self.spec_out(y)
|
||||
return spec_out
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
r"""
|
||||
T5 decoder layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
|
||||
# cond self attention: layer 0
|
||||
self.layer.append(
|
||||
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
||||
)
|
||||
|
||||
# cross attention: layer 1
|
||||
self.layer.append(
|
||||
T5LayerCrossAttention(
|
||||
d_model=d_model,
|
||||
d_kv=d_kv,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=dropout_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
# Film Cond MLP + dropout: last layer
|
||||
self.layer.append(
|
||||
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
||||
encoder_hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.layer[1](
|
||||
hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
|
||||
# Apply Film Conditional Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class T5LayerSelfAttentionCond(nn.Module):
|
||||
r"""
|
||||
T5 style self-attention layer with conditioning.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.layer_norm = T5LayerNorm(d_model)
|
||||
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# pre_self_attention_layer_norm
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if conditioning_emb is not None:
|
||||
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
||||
|
||||
# Self-attention block
|
||||
attention_output = self.attention(normed_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.dropout(attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerCrossAttention(nn.Module):
|
||||
r"""
|
||||
T5 style cross-attention layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
key_value_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.attention(
|
||||
normed_hidden_states,
|
||||
encoder_hidden_states=key_value_states,
|
||||
attention_mask=attention_mask.squeeze(1),
|
||||
)
|
||||
layer_output = hidden_states + self.dropout(attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class T5LayerFFCond(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward conditional layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
if conditioning_emb is not None:
|
||||
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
||||
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = hidden_states + self.dropout(forwarded_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedActDense(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward layer with gated activations and dropout.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act = NewGELUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_gelu = self.act(self.wi_0(hidden_states))
|
||||
hidden_linear = self.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
r"""
|
||||
T5 style layer normalization module.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Size of the input hidden states.
|
||||
eps (`float`, `optional`, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
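# Illustrative sketch (not part of this diff): the RMS-style normalization implemented above, applied to a
# half-precision input. The variance is accumulated in fp32 and the result is cast back to the weight
# dtype, matching the comments in `T5LayerNorm.forward`. The shapes below are arbitrary.
import torch

hidden_states = torch.randn(2, 3, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
eps = 1e-6

variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
normed = hidden_states * torch.rsqrt(variance + eps)  # promoted to fp32 by the fp32 rsqrt
normed = normed.to(weight.dtype)                       # cast back to fp16
out = weight * normed
assert out.dtype == torch.float16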
|
||||
|
||||
|
||||
class NewGELUActivation(nn.Module):
|
||||
"""
|
||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
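# Illustrative sketch (not part of this diff): the tanh GELU approximation above is also exposed in recent
# PyTorch releases as `F.gelu(..., approximate="tanh")`; the comparison below is only a sanity check on
# random data.
import math
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
builtin = F.gelu(x, approximate="tanh")
assert torch.allclose(manual, builtin, atol=1e-5)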
|
||||
|
||||
|
||||
class T5FiLMLayer(nn.Module):
|
||||
"""
|
||||
T5 style FiLM Layer.
|
||||
|
||||
Args:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
emb = self.scale_bias(conditioning_emb)
|
||||
scale, shift = torch.chunk(emb, 2, -1)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
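# Illustrative sketch (not part of this diff): FiLM conditioning as applied by `T5FiLMLayer` above. A
# conditioning embedding is projected to twice the feature size, split into a scale and a shift, and
# applied as x * (1 + scale) + shift. The layer sizes below are arbitrary.
import torch
from torch import nn

features, cond_dim = 16, 64
scale_bias = nn.Linear(cond_dim, features * 2, bias=False)

x = torch.randn(2, 10, features)            # (batch, seq, features)
conditioning = torch.randn(2, 1, cond_dim)  # (batch, 1, cond_dim)
scale, shift = torch.chunk(scale_bias(conditioning), 2, dim=-1)  # each (batch, 1, features)
out = x * (1 + scale) + shift               # broadcast over the sequence dimension
assert out.shape == x.shape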
src/diffusers/models/transformers/transformer_2d.py (new file, 458 lines)
@@ -0,0 +1,458 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly, as leaving `norm_type` unchanged might lead to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches and norm_type != "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
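# Illustrative sketch (not part of this diff): the gradient-checkpointing pattern used in `forward` below.
# The block call is wrapped in a closure and recomputed during the backward pass instead of keeping its
# activations in memory. The tiny linear layer here is a stand-in for a transformer block.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(8, 8)
hidden_states = torch.randn(2, 8, requires_grad=True)

hidden_states = checkpoint(lambda inputs: block(inputs), hidden_states, use_reentrant=False)
hidden_states.sum().backward()
assert block.weight.grad is not None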
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerNormZero`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
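# Illustrative sketch (not part of this diff): the "unpatchify" step at the end of
# `Transformer2DModel.forward` for patched inputs. A (batch, num_patches, patch*patch*channels) token
# tensor is folded back into a (batch, channels, height, width) latent. The sizes below are arbitrary.
import torch

batch, out_channels, patch_size = 2, 4, 2
height = width = 4                       # number of patches per side
tokens = torch.randn(batch, height * width, patch_size * patch_size * out_channels)

x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
latent = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
assert latent.shape == (batch, out_channels, height * patch_size, width * patch_size)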
src/diffusers/models/transformers/transformer_temporal.py (new file, 379 lines)
@@ -0,0 +1,379 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import AlphaBlender
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporalModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerNormZero`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
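# Illustrative sketch (not part of this diff): the reshapes used by `TransformerTemporalModel.forward`
# above so that attention runs over the frame axis. Frames are folded out of the batch dimension, and
# every spatial location becomes its own sequence of `num_frames` tokens. The sizes below are arbitrary.
import torch

batch_size, num_frames, channel, height, width = 2, 4, 8, 3, 3
batch_frames = batch_size * num_frames
hidden_states = torch.randn(batch_frames, channel, height, width)

x = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
x = x.permute(0, 2, 1, 3, 4)  # (batch, channel, frames, height, width)
x = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
assert x.shape == (batch_size * height * width, num_frames, channel)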
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input hidden_states.
|
||||
num_frames (`int`):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16, so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
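# Illustrative sketch (not part of this diff): how `TransformerSpatioTemporalModel.forward` above builds
# one positional id per frame, repeated for every sample in the batch, before projecting them with the
# sinusoidal `Timesteps` embedding. The sizes below are arbitrary.
import torch

batch_size, num_frames = 2, 5
frame_ids = torch.arange(num_frames)         # (num_frames,)
frame_ids = frame_ids.repeat(batch_size, 1)  # (batch_size, num_frames)
frame_ids = frame_ids.reshape(-1)            # one id per batch-frame entry
assert frame_ids.tolist() == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]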
|
||||
@@ -22,7 +22,6 @@ from ...utils import is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from ..dual_transformer_2d import DualTransformer2DModel
from ..normalization import AdaGroupNorm
from ..resnet import (
    Downsample2D,
@@ -34,7 +33,8 @@ from ..resnet import (
    ResnetBlockCondNorm2D,
    Upsample2D,
)
from ..transformer_2d import Transformer2DModel
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -20,7 +20,6 @@ from torch import nn
from ...utils import is_torch_version
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..dual_transformer_2d import DualTransformer2DModel
from ..resnet import (
    Downsample2D,
    ResnetBlock2D,
@@ -28,8 +27,9 @@ from ..resnet import (
    TemporalConvLayer,
    Upsample2D,
)
from ..transformer_2d import Transformer2DModel
from ..transformer_temporal import (
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
from ..transformers.transformer_temporal import (
    TransformerSpatioTemporalModel,
    TransformerTemporalModel,
)

@@ -33,7 +33,7 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,

@@ -29,7 +29,7 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (

@@ -35,7 +35,7 @@ from ...models.embeddings import (
)
from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel
from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging

@@ -19,7 +19,6 @@ from ....models.attention_processor import (
    AttnAddedKVProcessor2_0,
    AttnProcessor,
)
from ....models.dual_transformer_2d import DualTransformer2DModel
from ....models.embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
@@ -32,7 +31,8 @@ from ....models.embeddings import (
    Timesteps,
)
from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
from ....models.transformers.transformer_2d import Transformer2DModel
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu

@@ -10,7 +10,7 @@ from ...models.attention import FeedForward
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ...models.normalization import AdaLayerNorm
from ...models.transformer_2d import Transformer2DModelOutput
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import logging

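Taken together, the hunks above are purely mechanical path rewrites: every in-repo reference to transformer_2d, transformer_temporal, and dual_transformer_2d now goes through a transformers subpackage. For code outside the repo that imported these classes by their deep module path, the migration is the same one-line swap; a minimal sketch, assuming the deep paths shown in these hunks (shorter top-level aliases are not covered here):

# Before this change (old deep paths, removed in the hunks above):
#   from diffusers.models.transformer_2d import Transformer2DModel
#   from diffusers.models.transformer_temporal import TransformerTemporalModel

# After this change (new deep paths, as added above):
from diffusers.models.transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput
from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel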
@@ -24,7 +24,7 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import (
    backend_manual_seed,
    require_torch_accelerator_with_fp64,

tests/models/transformers/__init__.py (new empty file, 0 lines)

@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

from .test_modeling_common import ModelTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()
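The relative-import change in this last hunk follows from the moved test module now sitting one package level deeper, next to the new tests/models/transformers/__init__.py listed above. A rough sketch of the assumed layout (file names other than that __init__.py are illustrative, not taken from the diff):

# tests/models/test_modeling_common.py          -> provides ModelTesterMixin
# tests/models/transformers/__init__.py         -> new empty package marker
# tests/models/transformers/<moved test file>   -> hence ".." instead of "." in the import above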