From acabbc0033d4b4933fc651766a4aa026db2e6dc1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 17 Oct 2025 07:26:43 +0200 Subject: [PATCH] refactor attention --- .../transformers/transformer_kandinsky.py | 312 ++++++------------ 1 file changed, 99 insertions(+), 213 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45e4238cfb..1c3f8a2f68 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,10 +19,6 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import ( - BlockMask, - flex_attention, -) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -34,7 +30,7 @@ from ...utils import ( unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin from ..cache_utils import CacheMixin from ..embeddings import ( TimestepEmbedding, @@ -43,6 +39,7 @@ from ..embeddings import ( from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN logger = logging.get_logger(__name__) @@ -149,7 +146,15 @@ def nablaT_v2( k: Tensor, sta: Tensor, thr: float = 0.9, -) -> BlockMask: +): + if _CAN_USE_FLEX_ATTN: + from torch.nn.attention.flex_attention import BlockMask + else: + raise ValueError("Nabla attention is not supported with this version of PyTorch") + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + # Map estimation B, h, S, D = q.shape s1 = S // 64 @@ -174,13 +179,6 @@ def nablaT_v2( ) -@torch.autocast(device_type="cuda", enabled=False) -def apply_rotary(x, rope): - x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) - x_out = (rope * x_).sum(dim=-1) - return x_out.reshape(*x.shape).to(torch.bfloat16) - - class Kandinsky5TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() @@ -312,184 +310,83 @@ class Kandinsky5Modulation(nn.Module): return self.out_layer(self.activation(x)) -class Kandinsky5SDPAAttentionProcessor(nn.Module): - """Custom attention processor for standard SDPA attention""" +class Kandinsky5AttnProcessor: - def __call__( - self, - attn, - query, - key, - value, - **kwargs, - ): - # Process attention with the given query, key, value tensors - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1) + _attention_backend = None + _parallel_config = None - return out + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") -class Kandinsky5NablaAttentionProcessor(nn.Module): - """Custom attention processor for Nabla attention""" - - @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) - def __call__( - self, - attn, - query, - key, - value, - sparse_params=None, - **kwargs, - ): - if sparse_params is None: - raise ValueError("sparse_params is required for Nabla attention") + def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): + # query, key, value = self.get_qkv(x) + query = attn.to_query(hidden_states) - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() + if encoder_hidden_states is not None: + key = attn.to_key(encoder_hidden_states) + value = attn.to_value(encoder_hidden_states) - block_mask = nablaT_v2( - query, - key, - sparse_params["sta_mask"], - thr=sparse_params["P"], - ) - out = ( - flex_attention(query, key, value, block_mask=block_mask) - .transpose(1, 2) - .contiguous() - ) - out = out.flatten(-2, -1) - return out + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*cond_shape, attn.num_heads, -1) + value = value.reshape(*cond_shape, attn.num_heads, -1) + + else: + key = attn.to_key(hidden_states) + value = attn.to_value(hidden_states) + shape = query.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*shape, attn.num_heads, -1) + value = value.reshape(*shape, attn.num_heads, -1) -class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim + # query, key = self.norm_qk(query, key) + query = attn.query_norm(query.float()).type_as(query) + key = attn.key_norm(key.float()).type_as(key) - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) + def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) - - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def scaled_dot_product_attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def out_l(self, x): - return self.out_layer(x) - - def forward(self, x, rope): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) - - out = self.scaled_dot_product_attention(query, key, value) - - out = self.out_l(out) - return out - - -class 
Kandinsky5MultiheadSelfAttentionDec(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim - - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) - - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processors - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - self.nabla_processor = Kandinsky5NablaAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) - - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def nabla(self, query, key, value, sparse_params=None): - # Use the processor - return self.nabla_processor( - attn=self, - query=query, - key=key, - value=value, - sparse_params=sparse_params, - **{}, - ) - - def out_l(self, x): - return self.out_layer(x) - - def forward(self, x, rope, sparse_params=None): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) + if rotary_emb is not None: + query = apply_rotary(query, rotary_emb).type_as(query) + key = apply_rotary(key, rotary_emb).type_as(key) if sparse_params is not None: - out = self.nabla(query, key, value, sparse_params=sparse_params) + attn_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) else: - out = self.attention(query, key, value) + attn_mask = None + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(-2, -1) - out = self.out_l(out) - return out + attn_out = attn.out_layer(hidden_states) + return attn_out -class Kandinsky5MultiheadCrossAttention(nn.Module): - def __init__(self, num_channels, head_dim): + +class Kandinsky5Attention(nn.Module, AttentionModuleMixin): + + _default_processor_cls = Kandinsky5AttnProcessor + _available_processors = [ + Kandinsky5AttnProcessor, + ] + def __init__(self, num_channels, head_dim, processor=None): super().__init__() assert num_channels % head_dim == 0 self.num_heads = num_channels // head_dim @@ -501,42 +398,31 @@ class Kandinsky5MultiheadCrossAttention(nn.Module): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + sparse_params: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> 
torch.Tensor: - def get_qkv(self, x, cond): - query = self.to_query(x) - key = self.to_key(cond) - value = self.to_value(cond) + import inspect - shape, cond_shape = query.shape[:-1], key.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*cond_shape, self.num_heads, -1) - value = value.reshape(*cond_shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def out_l(self, x): - return self.out_layer(x) - - def forward(self, x, cond): - query, key, value = self.get_qkv(x, cond) - query, key = self.norm_qk(query, key) - - out = self.attention(query, key, value) - out = self.out_l(out) - return out + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs) class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): @@ -594,7 +480,7 @@ class Kandinsky5TransformerEncoderBlock(nn.Module): self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -605,7 +491,7 @@ class Kandinsky5TransformerEncoderBlock(nn.Module): ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) - out = self.self_attention(out, rope) + out = self.self_attention(out, rotary_emb=rope) x = (x.float() + gate.float() * out.float()).type_as(x) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) @@ -622,10 +508,10 @@ class Kandinsky5TransformerDecoderBlock(nn.Module): self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) + self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -637,12 +523,12 @@ class Kandinsky5TransformerDecoderBlock(nn.Module): shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) visual_out = 
(self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.cross_attention(visual_out, text_embed) + visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
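For reference, a minimal sketch of exercising the new Kandinsky5Attention / Kandinsky5AttnProcessor pair in isolation, assuming this patch is applied and the default SDPA backend of dispatch_attention_fn; the dimensions and tensor shapes below are illustrative only, and rotary_emb / sparse_params are omitted so the plain (non-Nabla) attention path is taken:

# Sketch only, not part of the patch; shapes/dims below are made up for illustration.
import torch
from diffusers.models.transformers.transformer_kandinsky import (
    Kandinsky5Attention,
    Kandinsky5AttnProcessor,
)

model_dim, head_dim = 1024, 128  # num_heads = model_dim // head_dim = 8
attn = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

hidden_states = torch.randn(2, 256, model_dim)          # (batch, seq, channels)
encoder_hidden_states = torch.randn(2, 77, model_dim)   # e.g. pooled text tokens

# Self-attention path: queries, keys and values all come from hidden_states.
self_out = attn(hidden_states)

# Cross-attention path: keys/values come from encoder_hidden_states instead.
cross_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

assert self_out.shape == cross_out.shape == hidden_states.shape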