update
@@ -46,6 +46,63 @@ else:
    XLA_AVAILABLE = False


class AttnProcessorMixin:
"""Attention processor used typically in processing Aura Flow."""

    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using standard separate projection matrices."""
        # Standard separate projections
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections

    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using fused QKV projection matrices."""
        # Fused QKV projection
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections

    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Public method to get projections based on whether we're using fused mode or not."""
        if self.is_fused and hasattr(attn, "to_qkv"):
            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)

        return self._get_projections(attn, hidden_states, encoder_hidden_states)

    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
        """Computes the attention. Can be overridden by hardware-specific implementations."""
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
        )


class Attention(nn.Module, AttentionModuleMixin):
    default_processor_class = AttnProcessorSDPA
    _available_processors = []
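
Editor's sketch (not part of the diff): how a processor built on AttnProcessorMixin resolves its projections. The toy module, tensor sizes, and the `is_fused` flag are illustrative assumptions; the mixin expects the attention module and the processor to supply these attributes.

import torch
import torch.nn as nn


class _ToyAttn(nn.Module):
    # Bare-bones stand-in exposing only what the unfused path of
    # AttnProcessorMixin._get_projections touches.
    def __init__(self, dim=64, heads=8):
        super().__init__()
        self.heads = heads
        self.norm_cross = False
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)


class _ToyProcessor(AttnProcessorMixin):
    is_fused = False  # assumed flag; the fused path additionally needs attn.to_qkv


attn, processor = _ToyAttn(), _ToyProcessor()
hidden_states = torch.randn(2, 16, 64)
query, key, value, encoder_projections = processor.get_projections(attn, hidden_states)
# query/key/value: (2, 16, 64); encoder_projections is None without add_*_proj layers
out = processor.attention_fn(
    query.view(2, 16, 8, 8).transpose(1, 2),
    key.view(2, 16, 8, 8).transpose(1, 2),
    value.view(2, 16, 8, 8).transpose(1, 2),
)  # (2, 8, 16, 8), computed with F.scaled_dot_product_attention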

@@ -1292,99 +1349,6 @@ class AllegroAttnProcessorSDPA:
        return hidden_states


class AuraFlowAttnProcessorSDPA:
    """Attention processor used typically in processing Aura Flow."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            raise ImportError(
                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Reshape.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # Apply QK norm.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Concatenate the projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)

            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, encoder_hidden_states.shape[1] :],
                hidden_states[:, : encoder_hidden_states.shape[1]],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class FusedAuraFlowAttnProcessorSDPA:
    """Attention processor used typically in processing Aura Flow with fused projections."""

@@ -2335,104 +2299,6 @@ class StableAudioAttnProcessorSDPA:
        return hidden_states


class HunyuanAttnProcessorSDPA:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class FusedHunyuanAttnProcessorSDPA:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused

@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


class AutoModel(ConfigMixin):

@@ -153,17 +153,39 @@ class AutoModel(ConfigMixin):
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
            "subfolder": subfolder,
        }

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]
        library = None
        orig_class_name = None
        from diffusers import pipelines

        library = importlib.import_module("diffusers")
        # Always attempt to fetch model_index.json first
        try:
            cls.config_name = "model_index.json"
            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

        model_cls = getattr(library, orig_class_name, None)
        if model_cls is None:
            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
            if subfolder is not None and subfolder in config:
                library, orig_class_name = config[subfolder]

        except (OSError, EntryNotFoundError) as e:
            logger.debug(e)

        # Unable to load from model_index.json so fallback to loading from config
        if library is None and orig_class_name is None:
            cls.config_name = "config.json"
            load_config_kwargs.update({"subfolder": subfolder})

            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
            orig_class_name = config["_class_name"]
            library = "diffusers"

        model_cls, _ = get_class_obj_and_candidates(
            library_name=library,
            class_name=orig_class_name,
            importable_classes=ALL_IMPORTABLE_CLASSES,
            pipelines=pipelines,
            is_pipeline_module=hasattr(pipelines, library),
        )

        kwargs = {**load_config_kwargs, **kwargs}
        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
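
Editor's usage sketch (not part of the diff; the repo id and subfolder are placeholders): with the change above, AutoModel first tries the repository's model_index.json and reads the (library, class) entry for the requested subfolder; if that file is missing it falls back to the subfolder's config.json and resolves the class from `_class_name`.

from diffusers import AutoModel

# Placeholder repo id and subfolder, for illustration only.
transformer = AutoModel.from_pretrained(
    "some-org/some-pipeline",
    subfolder="transformer",
)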

@@ -23,8 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention import Attention, AttentionMixin, AttentionModuleMixin
from ..attention_processor import (
    AttnProcessorMixin,
    AuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps

@@ -133,6 +134,98 @@ class AuraFlowPreFinalBlock(nn.Module):
        return x


class AuraFlowAttnProcessorSDPA(AttnProcessorMixin):
    """Attention processor used typically in processing Aura Flow."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            raise ImportError(
                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    def __call__(
        self,
        attn: AttentionModuleMixin,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size = hidden_states.shape[0]
        query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states)

        # `context` projections.
        if encoder_projections is not None:
            encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj = (
                encoder_projections
            )

        # Reshape.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # Apply QK norm.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Concatenate the projections.
        if encoder_projections is not None:
            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Attention.
        hidden_states = self.attention_fn(query, key, value, scale=attn.scale)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, encoder_hidden_states.shape[1] :],
                hidden_states[:, : encoder_hidden_states.shape[1]],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class AuraFlowAttention(Attention):
    default_processor_cls = AuraFlowAttnProcessorSDPA
    _available_processors = [AuraFlowAttnProcessorSDPA]
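
Editor's sketch (not part of the diff) of the concatenate-then-split ordering the processor above relies on: context tokens are prepended along the sequence dimension, so after attention the sample part is the trailing slice and the context part is the leading slice. Shapes are arbitrary.

import torch

sample = torch.randn(2, 10, 8)    # image/sample tokens
context = torch.randn(2, 3, 8)    # text/context tokens
joint = torch.cat([context, sample], dim=1)   # (2, 13, 8)
out_sample = joint[:, context.shape[1]:]      # (2, 10, 8)
out_context = joint[:, :context.shape[1]]     # (2, 3, 8)
assert out_sample.shape == sample.shape
assert out_context.shape == context.shape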


@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
    """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""

@@ -143,7 +236,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        processor = AuraFlowAttnProcessor2_0()
        self.attn = Attention(
        self.attn = AuraFlowAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,

@@ -206,7 +299,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
        self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        processor = AuraFlowAttnProcessor2_0()
        self.attn = Attention(
        self.attn = AuraFlowAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,

@@ -22,10 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
    AttentionModuleMixin,
)
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_processor import Attention, AttnProcessorMixin
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput

@@ -37,13 +35,13 @@ from .modeling_common import FeedForward
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class BaseCogVideoXAttnProcessor:
class CogVideoXAttnProcessorSDPA(AttnProcessorMixin):
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    compatible_backends = []
    compatible_backends = ["cuda", "cpu", "xpu"]

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):

@@ -51,56 +49,9 @@ class BaseCogVideoXAttnProcessor:
                "CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Public method to get projections based on whether we're using fused mode or not."""
        if self.is_fused and hasattr(attn, "to_qkv"):
            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)

        return self._get_projections(attn, hidden_states, encoder_hidden_states)

    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using standard separate projection matrices."""
        # Standard separate projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections

    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
        """Get projections using fused QKV projection matrices."""
        # Fused QKV projection
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # Handle encoder projections if present
        encoder_projections = None
        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
            encoder_projections = (encoder_query, encoder_key, encoder_value)

        return query, key, value, encoder_projections

    def _compute_attention(self, query, key, value, attention_mask=None):
        """Computes the attention. Can be overridden by hardware-specific implementations."""
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

    def __call__(
        self,
        attn: CogVideoXAttention,
        attn: AttentionModuleMixin,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,

@@ -138,7 +89,7 @@ class BaseCogVideoXAttnProcessor:
        if not attn.is_cross_attention:
            key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = self._compute_attention(query, key, value, attention_mask=attention_mask)
        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj

@@ -152,52 +103,10 @@ class BaseCogVideoXAttnProcessor:
        return hidden_states, encoder_hidden_states


class CogVideoXAttnProcessorSDPA(BaseCogVideoXAttnProcessor):
    compatible_backends = ["cuda", "cpu", "xpu"]


class CogVideoXAttention(nn.Module, AttentionModuleMixin):
class CogVideoXAttention(Attention):
    default_processor_cls = CogVideoXAttnProcessorSDPA
    _available_processors = [CogVideoXAttnProcessorSDPA]

    def __init__(
        self, query_dim, dim_head, heads, dropout=0.0, qk_norm=None, eps=1e-6, bias=False, out_bias=False
    ) -> None:
        # nn.Module state must be initialized before submodules are assigned
        # (Attention.__init__ is not reused by this lightweight variant).
        nn.Module.__init__(self)
        self.query_dim = query_dim
        self.out_dim = query_dim
        self.inner_dim = dim_head * heads
        self.use_bias = bias
        self.heads = heads

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=True)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=True)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        self.set_processor(self.default_processor_cls())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
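
Editor's sketch (not part of the diff) of the partial rotary-embedding pattern the CogVideoX processor relies on (visible in the hunk above): text tokens at the front of the sequence are left untouched and only the image tokens are rotated. The rotation below is a stand-in, not the real apply_rotary_emb.

import torch

batch, heads, text_seq_length, image_seq_length, head_dim = 1, 2, 4, 6, 8
query = torch.randn(batch, heads, text_seq_length + image_seq_length, head_dim)

def fake_rotate(x):
    # Stand-in for apply_rotary_emb: any elementwise transform works for the demo.
    return -x

rotated = query.clone()
rotated[:, :, text_seq_length:] = fake_rotate(rotated[:, :, text_seq_length:])
assert torch.equal(rotated[:, :, :text_seq_length], query[:, :, :text_seq_length])
assert torch.equal(rotated[:, :, text_seq_length:], -query[:, :, text_seq_length:])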


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):

@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention_processor import CogVideoXAttnProcessor2_0
from ..attention import Attention
from ..attention_processor import CogVideoXAttnProcessor2_0, CogVideoXAttnProcessorSDPA
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin

@@ -229,6 +229,11 @@ class PerceiverCrossAttention(nn.Module):
        return self.to_out(out)


class ConsisIDAttention(Attention):
    default_processor_cls = CogVideoXAttnProcessorSDPA
    _available_processors = [CogVideoXAttnProcessorSDPA]


@maybe_allow_in_graph
class ConsisIDBlock(nn.Module):
    r"""

@@ -287,7 +292,7 @@ class ConsisIDBlock(nn.Module):
        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
        self.attn1 = ConsisIDAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,

@@ -19,8 +19,8 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin
from ..attention_processor import HunyuanAttnProcessor2_0
from ..attention import Attention, AttentionMixin, AttnProcessorMixin
from ..attention_processor import HunyuanAttnProcessor2_0, HunyuanAttnProcessorSDPA
from ..embeddings import (
    HunyuanCombinedTimestepTextSizeStyleEmbedding,
    PatchEmbed,

@@ -56,6 +56,98 @@ class AdaLayerNormShift(nn.Module):
        return x


class HunyuanAttnProcessorSDPA(AttnProcessorMixin):
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
||||
"""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: AttentionModuleMixin,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class HunyuanDiTAttention(Attention):
    default_processor_cls = HunyuanAttnProcessorSDPA
    _available_processors = [HunyuanAttnProcessorSDPA]
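
Editor's sketch (not part of the diff) of the 4D round trip the processor above performs for spatial inputs: (batch, channel, height, width) is flattened to (batch, height*width, channel) for attention and restored afterwards. Sizes are arbitrary.

import torch

batch_size, channel, height, width = 2, 8, 4, 4
x = torch.randn(batch_size, channel, height, width)

tokens = x.view(batch_size, channel, height * width).transpose(1, 2)       # (2, 16, 8)
restored = tokens.transpose(-1, -2).reshape(batch_size, channel, height, width)
assert torch.equal(restored, x)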


@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""

@@ -111,7 +203,7 @@ class HunyuanDiTBlock(nn.Module):
        # 1. Self-Attn
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
        self.attn1 = HunyuanDiTAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,

@@ -119,7 +211,6 @@ class HunyuanDiTBlock(nn.Module):
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..attention_processor import Attention, AttentionMixin, AttnProcessorMixin
from ..embeddings import (
    LuminaCombinedTimestepCaptionEmbedding,
    LuminaPatchEmbed,

@@ -33,6 +33,101 @@ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNor
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LuminaAttnProcessorSDPA(AttnProcessorMixin):
    compatible_backends = ["cuda", "cpu", "xpu"]

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        query_rotary_emb: Optional[torch.Tensor] = None,
        key_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.view(batch_size, -1, attn.heads, head_dim)

        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply RoPE if needed
        if query_rotary_emb is not None:
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
        if key_rotary_emb is not None:
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if key_rotary_emb is None:
            softmax_scale = None
        else:
            if base_sequence_length is not None:
                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
            else:
                softmax_scale = attn.scale

        # perform Grouped-query Attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)

        return hidden_states


class LuminaNextAttention(Attention):
    default_processor_cls = LuminaAttnProcessorSDPA
    _available_processors = [LuminaAttnProcessorSDPA]
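
Editor's sketch (not part of the diff) of the grouped-query expansion used in the processor above: each key/value head is repeated n_rep times along a new axis and flattened so the key/value head count matches the query head count. Head counts are arbitrary.

import torch

batch, seq_len, kv_heads, head_dim, n_heads = 2, 5, 2, 8, 8
n_rep = n_heads // kv_heads          # 4 query heads share each key/value head
key = torch.randn(batch, seq_len, kv_heads, head_dim)

expanded = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
assert expanded.shape == (batch, seq_len, kv_heads * n_rep, head_dim)
assert torch.equal(expanded[:, :, 0], expanded[:, :, n_rep - 1])  # copies of head 0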


class LuminaNextDiTBlock(nn.Module):
    """
    A LuminaNextDiTBlock for LuminaNextDiT2DModel.

@@ -68,7 +163,7 @@ class LuminaNextDiTBlock(nn.Module):
        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))

        # Self-attention
        self.attn1 = Attention(
        self.attn1 = LuminaNextAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,

@@ -78,12 +173,11 @@ class LuminaNextDiTBlock(nn.Module):
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )
        self.attn1.to_out = nn.Identity()

        # Cross-attention
        self.attn2 = Attention(
        self.attn2 = LuminaNextAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,

@@ -93,7 +187,6 @@ class LuminaNextDiTBlock(nn.Module):
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(

@@ -175,7 +268,7 @@ class LuminaNextDiTBlock(nn.Module):
        return hidden_states


class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    """
    LuminaNextDiT: Diffusion model with a Transformer backbone.