mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,247 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .attention_processor import (
|
||||
AttentionModuleMixin,
|
||||
FusedJointAttnProcessorSDPA,
|
||||
JointAttnProcessorSDPA,
|
||||
SanaLinearAttnProcessorSDPA,
|
||||
)
|
||||
from .normalization import get_normalization
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
class SanaAttention(nn.Module, AttentionModuleMixin):
    """
    Attention implementation specialized for Sana models.

    This module implements lightweight multi-scale linear attention as used in Sana.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        num_attention_heads (`int`, *optional*):
            Number of attention heads. When `None`, it is derived as
            `int(in_channels // attention_head_dim * mult)`.
        attention_head_dim (`int`, defaults to 8): Dimension of each attention head.
        mult (`float`, defaults to 1.0): Multiplier for inner dimension (only used to derive the head count).
        norm_type (`str`, defaults to "batch_norm"): Type of normalization.
        kernel_sizes (`Tuple[int, ...]`, defaults to (5,)): Kernel sizes for multi-scale attention.
        eps (`float`, defaults to 1e-15): Stored as `self.eps`; not used by this class itself.
        residual_connection (`bool`, defaults to `False`):
            Stored as `self.residual_connection`; not used by this class itself.
    """

    # Set Sana-specific processor classes
    default_processor_class = SanaLinearAttnProcessorSDPA
    fused_processor_class = None  # Sana doesn't have a fused processor yet
    # `AttentionModuleMixin.set_use_npu_flash_attention` / `set_use_xla_flash_attention` instantiate
    # `self.default_processor_cls`; alias it so those helpers do not end up calling `None()`.
    default_processor_cls = default_processor_class

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 8,
        mult: float = 1.0,
        norm_type: str = "batch_norm",
        kernel_sizes: Tuple[int, ...] = (5,),
        eps: float = 1e-15,
        residual_connection: bool = False,
    ):
        super().__init__()

        # Core parameters
        self.eps = eps
        self.attention_head_dim = attention_head_dim
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        # Calculate dimensions
        num_attention_heads = (
            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
        )
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.heads = num_attention_heads

        # Query, key, value projections
        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)

        # Multi-scale attention: one depthwise + grouped projection per kernel size
        self.to_qkv_multiscale = nn.ModuleList()
        for kernel_size in kernel_sizes:
            self.to_qkv_multiscale.append(
                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
            )

        # Output layers. `to_out` consumes the base branch plus one branch per kernel size.
        self.nonlinearity = nn.ReLU()
        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
        self.norm_out = get_normalization(norm_type, num_features=out_channels)

        # Set default processor
        self.fused_projections = False
        self.set_processor(self.default_processor_class())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Process linear attention for Sana model inputs.

        Only `hidden_states` is forwarded to the processor; `encoder_hidden_states`, `attention_mask` and any extra
        keyword arguments are accepted for signature compatibility and are ignored.
        """
        return self.processor(self, hidden_states)
|
||||
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
    """
    Projection layer for Sana multi-scale attention.

    Applies a depthwise `kernel_size` x `kernel_size` convolution followed by a grouped 1x1 convolution. Both
    convolutions operate on `3 * in_channels` channels and are bias-free.

    Args:
        in_channels (`int`): Channel count of a single branch; the layers operate on three times this many channels.
        num_attention_heads (`int`): Number of attention heads; the 1x1 convolution uses `3 * num_attention_heads`
            groups.
        kernel_size (`int`): Kernel size of the depthwise convolution (padded by `kernel_size // 2`).
    """

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        kernel_size: int,
    ) -> None:
        super().__init__()

        qkv_channels = in_channels * 3

        # Depthwise convolution: one group per channel.
        self.proj_in = nn.Conv2d(
            in_channels=qkv_channels,
            out_channels=qkv_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=qkv_channels,
            bias=False,
        )
        # Grouped pointwise convolution.
        self.proj_out = nn.Conv2d(
            in_channels=qkv_channels,
            out_channels=qkv_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=num_attention_heads * 3,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the depthwise convolution and then the grouped 1x1 convolution on `hidden_states`."""
        return self.proj_out(self.proj_in(hidden_states))
|
||||
|
||||
|
||||
@maybe_allow_in_graph
class SD3Attention(nn.Module, AttentionModuleMixin):
    """
    Attention implementation specialized for SD3 models.

    This module implements the joint attention mechanism used in SD3,
    with native support for context pre-processing.

    Args:
        query_dim (`int`): Number of channels in query.
        cross_attention_dim (`int`, *optional*): Number of channels in encoder states.
        heads (`int`, defaults to 8): Number of attention heads.
        dim_head (`int`, defaults to 64): Dimension of each attention head.
        dropout (`float`, defaults to 0.0): Dropout probability.
        bias (`bool`, defaults to False): Whether to use bias in linear projections.
        added_kv_proj_dim (`int`, *optional*): Dimension for added key/value projections.
        context_pre_only (`bool`, defaults to False):
            When `True`, no context output projection is created (`to_add_out` is `None`).
    """

    # Set SD3-specific processor classes
    default_processor_class = JointAttnProcessorSDPA
    fused_processor_class = FusedJointAttnProcessorSDPA
    # `AttentionModuleMixin.set_use_npu_flash_attention` / `set_use_xla_flash_attention` instantiate
    # `self.default_processor_cls`; alias it so those helpers do not end up calling `None()`.
    default_processor_cls = default_processor_class

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        context_pre_only: bool = False,
    ):
        super().__init__()

        # Core parameters
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim
        self.heads = heads
        self.scale = dim_head**-0.5
        self.use_bias = bias
        self.scale_qk = True
        self.context_pre_only = context_pre_only

        # Cross-attention setup
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        # Projections for self-attention
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # Added projections for context processing. These attributes (and `added_proj_bias`) only exist when
        # `added_kv_proj_dim` is given.
        self.added_kv_proj_dim = added_kv_proj_dim
        if added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.added_proj_bias = bias

        # Output projection
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)])

        # Context output projection
        if added_kv_proj_dim is not None and not context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)
        else:
            self.to_add_out = None

        # Set default processor and fusion state
        self.fused_projections = False
        self.set_processor(self.default_processor_class())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Process joint attention for SD3 model inputs.

        Extra keyword arguments are forwarded to the processor only if its `__call__` signature names them. Any
        other keyword (except the silently dropped `ip_adapter_masks` / `ip_hidden_states`) triggers a warning and
        is ignored.
        """
        # Parameter names accepted by the active processor
        processor_params = inspect.signature(self.processor.__call__).parameters
        quiet_params = {"ip_adapter_masks", "ip_hidden_states"}

        # Check for unexpected parameters
        unexpected_params = [k for k in kwargs if k not in processor_params and k not in quiet_params]
        if unexpected_params:
            logger.warning(
                f"Parameters {unexpected_params} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )

        # Filter to only expected parameters
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in processor_params}

        # Process with appropriate processor
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **filtered_kwargs,
        )
|
||||
|
||||
@@ -22,7 +22,7 @@ from torch import nn
|
||||
from ..image_processor import IPAdapterMaskProcessor
|
||||
from ..utils import deprecate, is_torch_xla_available, logging
|
||||
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
|
||||
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
|
||||
from ..utils.torch_utils import is_torch_version
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -46,596 +46,6 @@ else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
class AttentionModuleMixin:
    """
    A mixin class that provides common methods for attention modules.

    This mixin adds functionality to set different attention processors, handle attention masks, compute attention
    scores, and manage projections.

    The methods read attributes that the host module is expected to define, such as `heads`, `scale`, `scale_qk`,
    `use_bias`, `is_cross_attention`, `to_q` / `to_k` / `to_v`, `upcast_attention`, `upcast_softmax`, `norm_cross`,
    `added_kv_proj_dim` and `sliceable_head_dim`.
    """

    # Default processor classes to be overridden by subclasses
    default_processor_cls = None
    # Processor classes searched by `_get_compatible_processor`; only read here, never mutated.
    _available_processors = []

    def _get_compatible_processor(self, backend):
        """
        Instantiate the first processor class in `_available_processors` that supports `backend`.

        Args:
            backend (`str`): Backend identifier looked up in each class's `compatible_backends`.

        Returns:
            A new instance of the first matching processor class.

        Raises:
            ValueError: If no registered processor class lists `backend` as compatible.
        """
        for processor_cls in self._available_processors:
            if backend in processor_cls.compatible_backends:
                return processor_cls()

        # Fail loudly here instead of returning `None`, which would otherwise be installed as the processor.
        raise ValueError(
            f"No attention processor compatible with backend '{backend}' is available for {self.__class__.__name__}."
        )

    def _get_default_attn_processor(self):
        """Return `AttnProcessorSDPA` when torch provides SDPA and `self.scale_qk` is set, else `AttnProcessor`."""
        # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
            return AttnProcessorSDPA()
        return AttnProcessor()

    @staticmethod
    def _fuse_linear_layers(layers, use_bias, device, dtype):
        """
        Build one `nn.Linear` whose weight (and optionally bias) is the row-wise concatenation of `layers`.

        Must be called with gradients disabled because the parameters are filled with in-place `copy_`.
        """
        stacked_weight = torch.cat([layer.weight.data for layer in layers])
        out_features, in_features = stacked_weight.shape

        fused = nn.Linear(in_features, out_features, bias=use_bias, device=device, dtype=dtype)
        fused.weight.copy_(stacked_weight)
        if use_bias:
            fused.bias.copy_(torch.cat([layer.bias.data for layer in layers]))
        return fused

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
        Set whether to use NPU flash attention from `torch_npu` or not.

        Args:
            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.

        Raises:
            ValueError: If NPU attention is requested but no NPU-compatible processor is available.
        """
        if use_npu_flash_attention:
            processor = self._get_compatible_processor("npu")
        else:
            processor = self.default_processor_cls()

        self.set_processor(processor)

    def set_use_xla_flash_attention(
        self,
        use_xla_flash_attention: bool,
        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
        is_flux=False,
    ) -> None:
        """
        Set whether to use XLA flash attention from `torch_xla` or not.

        Args:
            use_xla_flash_attention (`bool`):
                Whether to use pallas flash attention kernel from `torch_xla` or not.
            partition_spec (`Tuple[]`, *optional*):
                Specify the partition specification if using SPMD. Otherwise None. Currently not used by this method.
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model. Currently not used by this method.

        Raises:
            ImportError: If `torch_xla` is missing or too old for the requested kernel.
            ValueError: If no XLA-compatible processor is available.
        """
        if use_xla_flash_attention:
            # Raising a plain string is itself a `TypeError` in Python 3, so raise real exceptions.
            if not is_torch_xla_available():
                raise ImportError("torch_xla is not available")
            if is_torch_xla_version("<", "2.3"):
                raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
            if is_spmd() and is_torch_xla_version("<", "2.4"):
                raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
            processor = self._get_compatible_processor("xla")
        else:
            processor = self.default_processor_cls()

        self.set_processor(processor)

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        """
        Fuse the query, key, and value projections into a single projection for efficiency.

        For self-attention this creates `to_qkv`; for cross-attention it creates `to_kv`. If `add_q_proj`,
        `add_k_proj` and `add_v_proj` all exist, `to_added_qkv` is created as well. Passing `fuse=False` only clears
        the fused flags; previously created fused layers are left in place.

        Args:
            fuse (`bool`): Whether to fuse the projections or not.
        """
        # Skip if already in desired state
        if getattr(self, "fused_projections", False) == fuse:
            return

        if fuse:
            device = self.to_q.weight.data.device
            dtype = self.to_q.weight.data.dtype

            if not self.is_cross_attention:
                # Fuse self-attention projections
                self.to_qkv = self._fuse_linear_layers(
                    [self.to_q, self.to_k, self.to_v], self.use_bias, device, dtype
                )
            else:
                # Fuse cross-attention key-value projections
                self.to_kv = self._fuse_linear_layers([self.to_k, self.to_v], self.use_bias, device, dtype)

            # Handle added projections for models like SD3, Flux, etc.
            if (
                getattr(self, "add_q_proj", None) is not None
                and getattr(self, "add_k_proj", None) is not None
                and getattr(self, "add_v_proj", None) is not None
            ):
                self.to_added_qkv = self._fuse_linear_layers(
                    [self.add_q_proj, self.add_k_proj, self.add_v_proj], self.added_proj_bias, device, dtype
                )

        self.fused_projections = fuse
        self.processor.is_fused = fuse

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """
        Set whether to use memory efficient attention from `xformers` or not.

        The replacement processor is chosen from the type of the current processor (custom diffusion, added-KV,
        IP-Adapter, joint, or plain), and trained processor weights are copied over where applicable.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                Whether to use memory efficient attention from `xformers` or not.
            attention_op (`Callable`, *optional*):
                The attention operation to use. Defaults to `None` which uses the default attention operation from
                `xformers`.

        Raises:
            NotImplementedError: If the current processor is both an added-KV and a custom diffusion processor.
            ModuleNotFoundError: If `xformers` is requested but not installed.
            ValueError: If `xformers` is requested but CUDA is not available.
        """
        current = getattr(self, "processor", None)

        is_custom_diffusion = isinstance(
            current,
            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessorSDPA),
        )
        is_added_kv_processor = isinstance(
            current,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessorSDPA,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
            ),
        )
        is_ip_adapter = isinstance(
            current,
            (IPAdapterAttnProcessor, IPAdapterAttnProcessorSDPA, IPAdapterXFormersAttnProcessor),
        )
        is_joint_processor = isinstance(
            current,
            (
                JointAttnProcessorSDPA,
                XFormersJointAttnProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and is_custom_diffusion:
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                # Make sure we can run the memory efficient attention; any failure propagates to the caller.
                dtype = None
                if attention_op is not None:
                    op_fw, _ = attention_op
                    dtype, *_ = op_fw.SUPPORTED_DTYPES
                q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
                _ = xformers.ops.memory_efficient_attention(q, q, q)

            if is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=current.train_kv,
                    train_q_out=current.train_q_out,
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_custom_diffusion"):
                    processor.to(current.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            elif is_ip_adapter:
                processor = IPAdapterXFormersAttnProcessor(
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    num_tokens=current.num_tokens,
                    scale=current.scale,
                    attention_op=attention_op,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_ip"):
                    processor.to(device=current.to_k_ip[0].weight.device, dtype=current.to_k_ip[0].weight.dtype)
            elif is_joint_processor:
                processor = XFormersJointAttnProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_custom_diffusion:
                attn_processor_class = (
                    CustomDiffusionAttnProcessorSDPA
                    if hasattr(F, "scaled_dot_product_attention")
                    else CustomDiffusionAttnProcessor
                )
                processor = attn_processor_class(
                    train_kv=current.train_kv,
                    train_q_out=current.train_q_out,
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_custom_diffusion"):
                    processor.to(current.to_k_custom_diffusion.weight.device)
            elif is_ip_adapter:
                processor = IPAdapterAttnProcessorSDPA(
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    num_tokens=current.num_tokens,
                    scale=current.scale,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_ip"):
                    processor.to(device=current.to_k_ip[0].weight.device, dtype=current.to_k_ip[0].weight.dtype)
            else:
                # set attention processor
                processor = self._get_default_attn_processor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size: int) -> None:
        """
        Set the slice size for attention computation.

        Args:
            slice_size (`int`):
                The slice size for attention computation. `None` disables slicing and restores a non-sliced
                processor.

        Raises:
            ValueError: If `slice_size` is larger than `self.sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            processor = self._get_default_attn_processor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor") -> None:
        """
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Accepted for backward compatibility. This mixin has no LoRA-specific processor to build, so the
                current processor is returned regardless of this flag (instead of implicitly returning `None`).

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        return self.processor

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        """
        Reshape the tensor for multi-head attention processing.

        A 3D input `[batch_size, seq_len, dim]` (or 4D `[batch_size, extra_dim, seq_len, dim]`) is split into
        `self.heads` heads of size `dim // heads`.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. With `3` the heads are
                folded into the batch dimension; otherwise the 4D `[batch_size, heads, seq_len, dim // heads]` tensor
                is returned.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

        return tensor

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. It is added to the scaled
                query-key product before the softmax.

        Returns:
            `torch.Tensor`: The attention probabilities/scores, cast back to the dtype of `query`.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 the content of the input tensor is ignored, so an uninitialized buffer is sufficient.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self, attention_mask: Optional[torch.Tensor], target_length: int, batch_size: int, out_dim: int = 3
    ) -> Optional[torch.Tensor]:
        """
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`, *optional*): The attention mask to prepare. `None` is returned unchanged.
            target_length (`int`): The target length of the attention mask.
            batch_size (`int`): The batch size for repeating the attention mask.
            out_dim (`int`, *optional*, defaults to `3`): Output dimension.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Normalize the encoder hidden states.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.

        Raises:
            AssertionError: If `self.norm_cross` is neither an `nn.LayerNorm` nor an `nn.GroupNorm`.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            # Explicit raise (same exception type as the former `assert False`) so the check survives `python -O`.
            raise AssertionError(f"Unsupported norm_cross type: {type(self.norm_cross).__name__}")

        return encoder_hidden_states
|
||||
|
||||
|
||||
class AttnProcessorSDPA:
    r"""
    Attention processor that computes attention with `torch.nn.functional.scaled_dot_product_attention` (available
    from PyTorch 2.0).
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run attention for `attn` on `hidden_states`.

        Args:
            attn (`Attention`): Module supplying the projections, norms and configuration used below.
            hidden_states (`torch.Tensor`): Query-side input; a 4D input is flattened to a sequence and restored.
            encoder_hidden_states (`torch.Tensor`, *optional*): Key/value-side input; defaults to `hidden_states`.
            attention_mask (`torch.Tensor`, *optional*): Mask forwarded to `attn.prepare_attention_mask`.
            temb (`torch.Tensor`, *optional*): Passed to `attn.spatial_norm` when that norm is set.

        Returns:
            `torch.Tensor`: The attention output, divided by `attn.rescale_output_factor`.
        """
        scale_was_passed = kwargs.get("scale", None) is not None
        if args or scale_was_passed:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip_connection = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_4d_input = hidden_states.ndim == 4
        if is_4d_input:
            # Flatten the two trailing dims into one sequence axis and move channels last.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout
        for layer_index in (0, 1):
            hidden_states = attn.to_out[layer_index](hidden_states)

        if is_4d_input:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_connection

        return hidden_states / attn.rescale_output_factor
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class Attention(nn.Module, AttentionModuleMixin):
|
||||
default_processor_class = AttnProcessorSDPA
|
||||
_available_processors = []
|
||||
@@ -893,11 +303,7 @@ class Attention(nn.Module, AttentionModuleMixin):
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
||||
if processor is None:
|
||||
processor = (
|
||||
AttnProcessorSDPA()
|
||||
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
||||
else AttnProcessor()
|
||||
)
|
||||
processor = self.default_processor_class()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
@@ -947,97 +353,99 @@ class Attention(nn.Module, AttentionModuleMixin):
|
||||
)
|
||||
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
class AttnProcessorSDPA:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
attn: "Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
r"""Lightweight multi-scale linear attention"""
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: Tuple[int, ...] = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# To prevent circular import
|
||||
from .normalization import get_normalization
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, num_features=out_channels)
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
self.processor = SanaMultiscaleAttnProcessorSDPA()
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.processor(self, hidden_states)
|
||||
|
||||
|
||||
class CustomDiffusionAttnProcessor(nn.Module):
|
||||
r"""
|
||||
@@ -5304,98 +4712,104 @@ class StableAudioAttnProcessor2_0:
|
||||
def __new__(self, *args, **kwargs):
|
||||
deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
|
||||
deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
|
||||
return StableAudioAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
|
||||
class HunyuanAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
|
||||
deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
|
||||
return HunyuanAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class FusedHunyuanAttnProcessor2_0:
    """
    Deprecated name for `FusedHunyuanAttnProcessorSDPA`.

    Instantiating this class emits a deprecation notice and returns a `FusedHunyuanAttnProcessorSDPA` built from the
    same arguments.
    """

    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`"
        deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)

        # Must hand back the *fused* processor named in the message above, not the plain Hunyuan one.
        return FusedHunyuanAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGHunyuanAttnProcessor2_0(PAGHunyuanAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGHunyuanAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`"
|
||||
deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGHunyuanAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGCFGHunyuanAttnProcessor2_0(PAGCFGHunyuanAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGCFGHunyuanAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`"
|
||||
deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGCFGHunyuanAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class LuminaAttnProcessor2_0(LuminaAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class LuminaAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`"
|
||||
deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return LuminaAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0(FusedAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`"
|
||||
deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0(PAGIdentitySelfAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`"
|
||||
deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGIdentitySelfAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGCFGIdentitySelfAttnProcessor2_0(PAGCFGIdentitySelfAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGCFGIdentitySelfAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`"
|
||||
deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGCFGIdentitySelfAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class SanaMultiscaleAttnProcessor2_0(SanaMultiscaleAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class SanaMultiscaleAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`"
|
||||
deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return SanaMultiscaleAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0(LoRAAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class LoRAAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`"
|
||||
deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return LoRAAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class SanaLinearAttnProcessor2_0(SanaLinearAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class SanaLinearAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`"
|
||||
deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return SanaLinearAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGCFGSanaLinearAttnProcessor2_0(PAGCFGSanaLinearAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGCFGSanaLinearAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`"
|
||||
deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGCFGSanaLinearAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class PAGIdentitySanaLinearAttnProcessor2_0(PAGIdentitySanaLinearAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs):
|
||||
class PAGIdentitySanaLinearAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`"
|
||||
deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return PAGIdentitySanaLinearAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA):
|
||||
@@ -5405,11 +4819,12 @@ class IPAdapterAttnProcessor(IPAdapterAttnProcessorSDPA):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
class IPAdapterAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
|
||||
deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
return IPAdapterAttnProcessorSDPA(*args, **kwargs)
|
||||
|
||||
|
||||
ADDED_KV_ATTENTION_PROCESSORS = (
|
||||
|
||||
@@ -62,6 +62,98 @@ class ResBlock(nn.Module):
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
    """
    Two stacked bias-free convolutions applied to a tensor with `3 * in_channels` channels.

    `proj_in` is a depthwise convolution (one group per channel) with the given `kernel_size` and padding
    `kernel_size // 2`; `proj_out` is a 1x1 convolution split into `3 * num_attention_heads` groups.

    Args:
        in_channels (`int`): One third of the number of channels of the input tensor.
        num_attention_heads (`int`): Number of attention heads; sets the group count of `proj_out`.
        kernel_size (`int`): Kernel size of the depthwise convolution.
    """

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        kernel_size: int,
    ) -> None:
        super().__init__()

        qkv_channels = in_channels * 3

        depthwise_conv = nn.Conv2d(
            in_channels=qkv_channels,
            out_channels=qkv_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=qkv_channels,
            bias=False,
        )
        grouped_pointwise_conv = nn.Conv2d(
            in_channels=qkv_channels,
            out_channels=qkv_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=3 * num_attention_heads,
            bias=False,
        )

        self.proj_in = depthwise_conv
        self.proj_out = grouped_pointwise_conv

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `proj_in` followed by `proj_out`."""
        return self.proj_out(self.proj_in(hidden_states))
|
||||
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
    r"""Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 8,
        mult: float = 1.0,
        norm_type: str = "batch_norm",
        kernel_sizes: Tuple[int, ...] = (5,),
        eps: float = 1e-15,
        residual_connection: bool = False,
    ):
        super().__init__()

        # To prevent circular import
        from ..normalization import get_normalization

        self.eps = eps
        self.attention_head_dim = attention_head_dim
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        # Derive the head count from the channel width unless the caller fixed it.
        if num_attention_heads is None:
            num_attention_heads = int(in_channels // attention_head_dim * mult)
        inner_dim = num_attention_heads * attention_head_dim

        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)

        # One extra projection branch per kernel size.
        self.to_qkv_multiscale = nn.ModuleList(
            [
                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
                for kernel_size in kernel_sizes
            ]
        )

        self.nonlinearity = nn.ReLU()
        # The output projection consumes the base branch plus one branch per kernel size.
        num_branches = 1 + len(kernel_sizes)
        self.to_out = nn.Linear(inner_dim * num_branches, out_channels, bias=False)
        self.norm_out = get_normalization(norm_type, num_features=out_channels)

        self.processor = SanaMultiscaleAttnProcessorSDPA()

    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Compute `(value @ key^T) @ query` in float32, normalized by an appended row of ones."""
        # Append one row of ones to the second-to-last dim of `value`; after both matmuls that extra row
        # carries the normalizer that is divided out below.
        # NOTE(review): the `[:, :, -1:]` slicing assumes 4D inputs so that dim 2 is the padded axis - confirm.
        padded_value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
        value_key = torch.matmul(padded_value, key.transpose(-1, -2))
        attended = torch.matmul(value_key, query).to(dtype=torch.float32)

        numerator = attended[:, :, :-1]
        denominator = attended[:, :, -1:] + self.eps
        return numerator / denominator

    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Compute `value @ normalize(key^T @ query)`, normalizing the scores over dim 2 in float32."""
        raw_scores = torch.matmul(key.transpose(-1, -2), query).to(dtype=torch.float32)
        normalizer = torch.sum(raw_scores, dim=2, keepdim=True) + self.eps
        weights = (raw_scores / normalizer).to(value.dtype)
        return torch.matmul(value, weights)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate the computation to `self.processor`."""
        return self.processor(self, hidden_states)
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -21,7 +21,8 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import Attention, SpatialNorm
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from ..downsampling import Downsample2D
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention
|
||||
from ..attention import Attention
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
@@ -23,7 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import MochiVaeAttnProcessor2_0
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
|
||||
|
||||
@@ -22,7 +22,8 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
@@ -21,7 +21,7 @@ from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from .activations import FP32SiLU, get_activation
|
||||
from .attention_processor import Attention
|
||||
from .attention import Attention
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
|
||||
@@ -23,11 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, AttentionMixin
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -267,7 +265,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
|
||||
|
||||
@@ -357,105 +355,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -22,12 +22,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
AttentionModuleMixin,
|
||||
AttentionProcessor,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
FusedCogVideoXAttnProcessor2_0,
|
||||
)
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
@@ -103,7 +100,7 @@ class BaseCogVideoXAttnProcessor:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
attn: CogVideoXAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
@@ -260,7 +257,7 @@ class CogVideoXBlock(nn.Module):
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.attn1 = Attention(
|
||||
self.attn1 = CogVideoXAttention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
@@ -268,7 +265,6 @@ class CogVideoXBlock(nn.Module):
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
@@ -325,7 +321,7 @@ class CogVideoXBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin, AttentionMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
@@ -499,105 +495,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
|
||||
from ..attention import Attention, AttentionMixin
|
||||
from ..attention_processor import CogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -349,7 +349,7 @@ class ConsisIDBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, AttentionMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
|
||||
|
||||
@@ -621,65 +621,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -19,16 +19,17 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from .modeling_common BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class DiTTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -19,7 +19,8 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
|
||||
from ..attention import Attention, AttentionMixin
|
||||
from ..attention_processor import HunyuanAttnProcessor2_0
|
||||
from ..embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
PatchEmbed,
|
||||
@@ -200,7 +201,7 @@ class HunyuanDiTBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
class HunyuanDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
"""
|
||||
HunYuanDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
@@ -318,105 +319,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
|
||||
@@ -19,15 +19,16 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from .modeling_common BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin, AttentionMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
|
||||
@@ -17,12 +17,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from ..attention_processor import Attention, JointAttnProcessor2_0
|
||||
from ..embeddings import SinusoidalPositionalEmbedding
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -11,25 +11,26 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from .modeling_common BasicTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import AttnProcessor
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class PixArtTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
|
||||
https://arxiv.org/abs/2403.04692).
|
||||
@@ -184,65 +185,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
in_features=self.config.caption_channels, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using inherited method from AttentionMixin
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
@@ -252,45 +195,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -8,16 +8,16 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from .modeling_common BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput):
|
||||
predicted_image_embedding: torch.Tensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin, AttentionMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
@@ -166,65 +166,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
|
||||
@@ -21,10 +21,9 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import Attention, AttentionMixin
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionModuleMixin,
|
||||
AttentionProcessor,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
@@ -388,7 +387,7 @@ class SanaTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
|
||||
|
||||
@@ -513,65 +512,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -21,10 +21,8 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention import Attention, AttentionMixin, FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
StableAudioAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -187,7 +185,7 @@ class StableAudioDiTBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
class StableAudioDiTModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
"""
|
||||
The Diffusion Transformer model introduced in Stable Audio.
|
||||
|
||||
@@ -279,65 +277,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
|
||||
def set_default_attn_processor(self):
|
||||
|
||||
@@ -19,11 +19,12 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import LegacyConfigMixin, register_to_config
|
||||
from ...utils import deprecate, logging
|
||||
from .modeling_common BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import LegacyModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -36,7 +37,7 @@ class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin, AttentionMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
|
||||
@@ -13,16 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention import Attention, AttentionMixin, FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -130,7 +128,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
|
||||
Diffusion](https://huggingface.co/papers/2403.05121).
|
||||
@@ -229,65 +227,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -24,12 +24,13 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import AttentionModuleMixin, AttentionProcessor
|
||||
from ...models.attention_processor import AttentionModuleMixin
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_torch_xla_version
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -592,7 +593,7 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
@@ -687,97 +688,9 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..attention import Attention, AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
@@ -819,7 +819,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
@@ -962,65 +962,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -20,15 +20,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...models.attention import FeedForward, JointTransformerBlock
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionModuleMixin,
|
||||
AttentionProcessor,
|
||||
FusedJointAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
@@ -280,7 +278,7 @@ class SD3SingleTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class SD3Transformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin, AttentionMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
|
||||
@@ -416,105 +414,9 @@ class SD3Transformer2DModel(
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedJointAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -19,10 +19,11 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from .modeling_common BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import AlphaBlender
|
||||
from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,7 +39,7 @@ class TransformerTemporalModelOutput(BaseOutput):
|
||||
sample: torch.Tensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
@@ -202,7 +203,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
class TransformerSpatioTemporalModel(nn.Module, AttentionMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ from torch import nn
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
from ..normalization import AdaGroupNorm
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
|
||||
@@ -21,7 +21,8 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention_processor import Attention
|
||||
from ..attention import Attention
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user