
[core] support attention backends for LTX (#12021)

* support attention backends for LTX

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* reviewer feedback.

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Author:    Sayak Paul
Date:      2025-07-30 16:35:11 +05:30
Committed: GitHub
Parent:    843e3f9346
Commit:    c052791b5f

@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """

+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )

     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
         return hidden_states


+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
         super().__init__()

         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
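A subtlety in the processor rewrite above: F.scaled_dot_product_attention expects tensors shaped (batch, heads, seq, head_dim), which is why the old code transposed after unflatten, while dispatch_attention_fn takes (batch, seq, heads, head_dim) and handles any per-backend layout internally, hence the dropped .transpose(1, 2) calls. A standalone sketch of this layout contract, assuming the module path diffusers.models.attention_dispatch is importable and using only the keyword arguments visible in this diff:

import torch
import torch.nn.functional as F
from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, seq_len, heads, head_dim = 2, 128, 8, 64
query = torch.randn(batch, seq_len, heads, head_dim)
key = torch.randn(batch, seq_len, heads, head_dim)
value = torch.randn(batch, seq_len, heads, head_dim)

# New path: (batch, seq, heads, head_dim) in and out; backend=None mirrors the
# processor's default _attention_backend and falls back to PyTorch SDPA.
out = dispatch_attention_fn(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, backend=None
)

# Old path, for comparison: SDPA wants (batch, heads, seq, head_dim).
ref = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
).transpose(1, 2)

torch.testing.assert_close(out, ref)  # both layouts compute the same attention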