refactor attention
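This change collapses the Kandinsky5-specific attention stack (Kandinsky5SDPAAttentionProcessor, Kandinsky5NablaAttentionProcessor, and the separate encoder/decoder/cross attention modules) into a single Kandinsky5Attention module built on AttentionModuleMixin, with a swappable Kandinsky5AttnProcessor that routes the computation through dispatch_attention_fn. The general shape of the pattern, as a minimal standalone sketch (class and variable names here are illustrative, not the ones added by this commit):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SketchAttnProcessor:
        # The processor owns the attention math; the module owns the projections.
        def __call__(self, attn, hidden_states, encoder_hidden_states=None):
            context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
            b, q_len, _ = hidden_states.shape
            k_len = context.shape[1]
            # Project and split into heads: (batch, heads, seq, head_dim).
            q = attn.to_query(hidden_states).view(b, q_len, attn.num_heads, -1).transpose(1, 2)
            k = attn.to_key(context).view(b, k_len, attn.num_heads, -1).transpose(1, 2)
            v = attn.to_value(context).view(b, k_len, attn.num_heads, -1).transpose(1, 2)
            out = F.scaled_dot_product_attention(q, k, v)
            out = out.transpose(1, 2).reshape(b, q_len, -1)
            return attn.out_layer(out)


    class SketchAttention(nn.Module):
        def __init__(self, num_channels, head_dim, processor=None):
            super().__init__()
            self.num_heads = num_channels // head_dim
            self.to_query = nn.Linear(num_channels, num_channels)
            self.to_key = nn.Linear(num_channels, num_channels)
            self.to_value = nn.Linear(num_channels, num_channels)
            self.out_layer = nn.Linear(num_channels, num_channels)
            self.processor = processor or SketchAttnProcessor()

        def forward(self, hidden_states, encoder_hidden_states=None):
            # Swapping self.processor changes the attention implementation without
            # touching the module's weights or its call sites.
            return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states)


    # Self-attention and cross-attention go through the same module:
    attn = SketchAttention(num_channels=256, head_dim=64)
    x = torch.randn(2, 16, 256)
    text = torch.randn(2, 77, 256)
    print(attn(x).shape, attn(x, encoder_hidden_states=text).shape)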
@@ -19,10 +19,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import BoolTensor, IntTensor, Tensor, nn
from torch.nn.attention.flex_attention import (
    BlockMask,
    flex_attention,
)

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -34,7 +30,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin
from ..cache_utils import CacheMixin
from ..embeddings import (
    TimestepEmbedding,
@@ -43,6 +39,7 @@ from ..embeddings import (
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN

logger = logging.get_logger(__name__)

@@ -149,7 +146,15 @@ def nablaT_v2(
    k: Tensor,
    sta: Tensor,
    thr: float = 0.9,
) -> BlockMask:
):
    if _CAN_USE_FLEX_ATTN:
        from torch.nn.attention.flex_attention import BlockMask
    else:
        raise ValueError("Nabla attention is not supported with this version of PyTorch")

    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()

    # Map estimation
    B, h, S, D = q.shape
    s1 = S // 64
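For context on the signature change above: nablaT_v2 now only imports BlockMask when flex attention is available, and BlockMask is the sparse-block structure consumed by torch.nn.attention.flex_attention. A rough standalone sketch of that API (PyTorch 2.5+ with a CUDA build assumed; the mask function here is an illustrative causal pattern, not the Nabla mask):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal(b, h, q_idx, kv_idx):
        # mask_mod: return True where a query position may attend to a key position.
        return q_idx >= kv_idx

    B, H, S, D = 1, 8, 1024, 64
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # A BlockMask records which (query-block, key-block) tiles are active, so the
    # kernel can skip whole tiles instead of masking individual scores.
    block_mask = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S)
    out = flex_attention(q, k, v, block_mask=block_mask)   # (B, H, S, D)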
@@ -174,13 +179,6 @@ def nablaT_v2(
    )


@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
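The apply_rotary helper above expects rope to carry one 2x2 rotation matrix per channel pair, so that (rope * x_).sum(dim=-1) multiplies each pair (x0, x1) by [[cos, -sin], [sin, cos]]. One way to build a compatible tensor, as a sketch (the frequency schedule and the extra singleton head axis are assumptions, not taken from this file):

    import torch

    def build_rope(positions, dim, theta=10000.0):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))    # (dim/2,)
        angles = positions.float()[:, None] * freqs[None, :]                # (seq, dim/2)
        cos, sin = angles.cos(), angles.sin()
        rows = [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)]
        return torch.stack(rows, dim=-2)                                    # (seq, dim/2, 2, 2)

    x = torch.randn(2, 16, 8, 64)                         # (batch, seq, heads, head_dim)
    rope = build_rope(torch.arange(16), dim=64)[:, None]  # (seq, 1, dim/2, 2, 2), broadcasts over heads
    rotated = apply_rotary(x, rope)                        # same shape as x, cast to bfloat16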
@@ -312,184 +310,83 @@ class Kandinsky5Modulation(nn.Module):
        return self.out_layer(self.activation(x))


class Kandinsky5SDPAAttentionProcessor(nn.Module):
    """Custom attention processor for standard SDPA attention"""
class Kandinsky5AttnProcessor:

    def __call__(
        self,
        attn,
        query,
        key,
        value,
        **kwargs,
    ):
        # Process attention with the given query, key, value tensors
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1)
    _attention_backend = None
    _parallel_config = None

        return out
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

class Kandinsky5NablaAttentionProcessor(nn.Module):
    """Custom attention processor for Nabla attention"""

    @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
    def __call__(
        self,
        attn,
        query,
        key,
        value,
        sparse_params=None,
        **kwargs,
    ):
        if sparse_params is None:
            raise ValueError("sparse_params is required for Nabla attention")
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
        # query, key, value = self.get_qkv(x)
        query = attn.to_query(hidden_states)

        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        if encoder_hidden_states is not None:
            key = attn.to_key(encoder_hidden_states)
            value = attn.to_value(encoder_hidden_states)

        block_mask = nablaT_v2(
            query,
            key,
            sparse_params["sta_mask"],
            thr=sparse_params["P"],
        )
        out = (
            flex_attention(query, key, value, block_mask=block_mask)
            .transpose(1, 2)
            .contiguous()
        )
        out = out.flatten(-2, -1)
        return out
            shape, cond_shape = query.shape[:-1], key.shape[:-1]
            query = query.reshape(*shape, attn.num_heads, -1)
            key = key.reshape(*cond_shape, attn.num_heads, -1)
            value = value.reshape(*cond_shape, attn.num_heads, -1)

        else:
            key = attn.to_key(hidden_states)
            value = attn.to_value(hidden_states)

            shape = query.shape[:-1]
            query = query.reshape(*shape, attn.num_heads, -1)
            key = key.reshape(*shape, attn.num_heads, -1)
            value = value.reshape(*shape, attn.num_heads, -1)

class Kandinsky5MultiheadSelfAttentionEnc(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim
        # query, key = self.norm_qk(query, key)
        query = attn.query_norm(query.float()).type_as(query)
        key = attn.key_norm(key.float()).type_as(key)

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

        # Initialize attention processor
        self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()

    def get_qkv(self, x):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)

        shape = query.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        return query, key, value

    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def scaled_dot_product_attention(self, query, key, value):
        # Use the processor
        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, rope):
        query, key, value = self.get_qkv(x)
        query, key = self.norm_qk(query, key)
        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)

        out = self.scaled_dot_product_attention(query, key, value)

        out = self.out_l(out)
        return out

class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

        # Initialize attention processors
        self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
        self.nabla_processor = Kandinsky5NablaAttentionProcessor()

    def get_qkv(self, x):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)

        shape = query.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        return query, key, value

    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def attention(self, query, key, value):
        # Use the processor
        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})

    def nabla(self, query, key, value, sparse_params=None):
        # Use the processor
        return self.nabla_processor(
            attn=self,
            query=query,
            key=key,
            value=value,
            sparse_params=sparse_params,
            **{},
        )

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, rope, sparse_params=None):
        query, key, value = self.get_qkv(x)
        query, key = self.norm_qk(query, key)
        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)
        if rotary_emb is not None:
            query = apply_rotary(query, rotary_emb).type_as(query)
            key = apply_rotary(key, rotary_emb).type_as(key)

        if sparse_params is not None:
            out = self.nabla(query, key, value, sparse_params=sparse_params)
            attn_mask = nablaT_v2(
                query,
                key,
                sparse_params["sta_mask"],
                thr=sparse_params["P"],
            )
        else:
            out = self.attention(query, key, value)
            attn_mask = None

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(-2, -1)

        out = self.out_l(out)
        return out
        attn_out = attn.out_layer(hidden_states)
        return attn_out

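Pieced together from the new-side rows of this hunk, the added processor path reads roughly as follows. The line ordering is inferred from the interleaved view above, so treat it as a sketch of the new code rather than a verbatim copy; apply_rotary, nablaT_v2, and dispatch_attention_fn are the module-level helpers from this file:

    class Kandinsky5AttnProcessor:
        _attention_backend = None
        _parallel_config = None

        def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
            query = attn.to_query(hidden_states)
            if encoder_hidden_states is not None:
                # Cross-attention: keys/values come from the conditioning stream.
                key = attn.to_key(encoder_hidden_states)
                value = attn.to_value(encoder_hidden_states)
                shape, cond_shape = query.shape[:-1], key.shape[:-1]
                query = query.reshape(*shape, attn.num_heads, -1)
                key = key.reshape(*cond_shape, attn.num_heads, -1)
                value = value.reshape(*cond_shape, attn.num_heads, -1)
            else:
                # Self-attention: keys/values come from the same stream.
                key = attn.to_key(hidden_states)
                value = attn.to_value(hidden_states)
                shape = query.shape[:-1]
                query = query.reshape(*shape, attn.num_heads, -1)
                key = key.reshape(*shape, attn.num_heads, -1)
                value = value.reshape(*shape, attn.num_heads, -1)

            query = attn.query_norm(query.float()).type_as(query)
            key = attn.key_norm(key.float()).type_as(key)

            if rotary_emb is not None:
                query = apply_rotary(query, rotary_emb).type_as(query)
                key = apply_rotary(key, rotary_emb).type_as(key)

            if sparse_params is not None:
                attn_mask = nablaT_v2(query, key, sparse_params["sta_mask"], thr=sparse_params["P"])
            else:
                attn_mask = None

            hidden_states = dispatch_attention_fn(
                query, key, value,
                attn_mask=attn_mask,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )
            hidden_states = hidden_states.flatten(-2, -1)
            return attn.out_layer(hidden_states)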
class Kandinsky5MultiheadCrossAttention(nn.Module):
    def __init__(self, num_channels, head_dim):

class Kandinsky5Attention(nn.Module, AttentionModuleMixin):

    _default_processor_cls = Kandinsky5AttnProcessor
    _available_processors = [
        Kandinsky5AttnProcessor,
    ]
    def __init__(self, num_channels, head_dim, processor=None):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim
@@ -501,42 +398,31 @@ class Kandinsky5MultiheadCrossAttention(nn.Module):
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

        # Initialize attention processor
        self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        sparse_params: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:

    def get_qkv(self, x, cond):
        query = self.to_query(x)
        key = self.to_key(cond)
        value = self.to_value(cond)
        import inspect

        shape, cond_shape = query.shape[:-1], key.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*cond_shape, self.num_heads, -1)
        value = value.reshape(*cond_shape, self.num_heads, -1)

        return query, key, value

    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def attention(self, query, key, value):
        # Use the processor
        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, cond):
        query, key, value = self.get_qkv(x, cond)
        query, key = self.norm_qk(query, key)

        out = self.attention(query, key, value)
        out = self.out_l(out)
        return out
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}

        return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs)

class Kandinsky5FeedForward(nn.Module):
    def __init__(self, dim, ff_dim):
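The call-site hunks below follow from the same refactor: blocks now construct Kandinsky5Attention with an explicit processor and pass rotary embeddings, text context, and sparsity parameters as keyword arguments. A minimal usage sketch, assuming the classes in this file are importable and using illustrative sizes:

    import torch

    attn = Kandinsky5Attention(1024, 128, processor=Kandinsky5AttnProcessor())

    x = torch.randn(2, 16, 1024)       # visual tokens
    text = torch.randn(2, 77, 1024)    # text tokens

    self_out = attn(x)                                   # self-attention
    cross_out = attn(x, encoder_hidden_states=text)      # cross-attention against the text stream
    attn.set_processor(Kandinsky5AttnProcessor())        # processors stay swappable via AttentionModuleMixin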
@@ -594,7 +480,7 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
        self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim)
        self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
@@ -605,7 +491,7 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
        )
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
        out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
        out = self.self_attention(out, rope)
        out = self.self_attention(out, rotary_emb=rope)
        x = (x.float() + gate.float() * out.float()).type_as(x)

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
@@ -622,10 +508,10 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
        self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
        self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
        self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
@@ -637,12 +523,12 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):

        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
        visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
        visual_out = self.self_attention(visual_out, rope, sparse_params)
        visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params)
        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)

        shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
        visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
        visual_out = self.cross_attention(visual_out, text_embed)
        visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed)
        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)