Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[core] support attention backends for LTX (#12021)
* support attention backends for LTX

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* reviewer feedback.

---------

Co-authored-by: Aryan <aryan@huggingface.co>
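In practice, this change routes the LTX transformer's attention through the shared attention dispatcher, so a backend can be selected per model instead of always calling F.scaled_dot_product_attention. Below is a minimal usage sketch, assuming a diffusers version that exposes the `set_attention_backend` helper and the `attention_backend` context manager shipped alongside the dispatcher; backend names such as "flash" depend on which kernels are installed and registered locally.

import torch

from diffusers import LTXPipeline
from diffusers.models.attention_dispatch import attention_backend

# Usage sketch, not part of this diff. Assumes the Lightricks/LTX-Video checkpoint is
# available and that `set_attention_backend` / `attention_backend` exist in this diffusers version.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# Persistently pick a backend for the transformer; leaving it unset keeps native SDPA.
pipe.transformer.set_attention_backend("flash")  # requires flash-attn to be installed

# Or scope a backend to a single call with the context manager.
with attention_backend("native"):
    video = pipe(prompt="A sailboat gliding over calm water at sunset", num_frames=49).frames[0]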
@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """

+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )

     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
         return hidden_states


+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
         super().__init__()

         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@ class LTXVideoTransformerBlock(nn.Module):
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@ class LTXVideoTransformerBlock(nn.Module):
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ class LTXVideoTransformerBlock(nn.Module):


 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
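Inside the processor, the change in `__call__` boils down to: project and normalize query/key/value, keep them in (batch, seq_len, heads, head_dim) layout, and hand them to a backend-selecting dispatch function instead of calling F.scaled_dot_product_attention directly, with `_attention_backend = None` meaning "use the default". The following self-contained sketch illustrates that dispatch pattern; the registry, function name, and the "native" key are invented for the example and are not diffusers' actual attention_dispatch implementation.

from typing import Callable, Dict, Optional

import torch
import torch.nn.functional as F

# Toy backend registry: each entry maps a name to an attention callable that
# accepts (batch, seq, heads, head_dim) tensors, like the processor above produces.
_BACKENDS: Dict[str, Callable[..., torch.Tensor]] = {
    "native": lambda q, k, v, **kw: F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), **kw
    ).transpose(1, 2),
}


def dispatch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    backend: Optional[str] = None,
) -> torch.Tensor:
    # backend=None means "use the default", mirroring `_attention_backend = None`.
    fn = _BACKENDS["native"] if backend is None else _BACKENDS[backend]
    return fn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)


q = k = v = torch.randn(1, 16, 2, 8)  # batch, seq, heads, head_dim
out = dispatch_attention(q, k, v)     # same layout in and out, like dispatch_attention_fn
print(out.shape)  # torch.Size([1, 16, 2, 8])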