From 7084106eaaa9b998efd520e72b4a69a6e2dd90cf Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 20:38:40 +0000 Subject: [PATCH] remove unused imports --- .../transformers/transformer_kandinsky.py | 250 ++++++++++-------- 1 file changed, 142 insertions(+), 108 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 24b2c4ae99..ac2fe58d60 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,21 +19,27 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - flex_attention) -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from torch.nn.attention.flex_attention import ( + BlockMask, + flex_attention, +) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, - unscale_lora_layers) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ..attention_dispatch import dispatch_attention_fn +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding, - Timesteps, get_1d_rotary_pos_embed) +from ..embeddings import ( + TimestepEmbedding, + get_1d_rotary_pos_embed, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -95,7 +101,7 @@ def local_patching(x, shape, group_size, dim=0): g2, width // g3, g3, - *x.shape[dim + 3 :] + *x.shape[dim + 3 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -105,7 +111,7 @@ def local_patching(x, shape, group_size, dim=0): dim + 1, dim + 3, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) return x @@ -122,7 +128,7 @@ def local_merge(x, shape, group_size, dim=0): g1, g2, g3, - *x.shape[dim + 2 :] + *x.shape[dim + 2 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -132,7 +138,7 @@ def local_merge(x, shape, group_size, dim=0): dim + 4, dim + 2, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) return x @@ -172,15 +178,7 @@ def sdpa(q, k, v): query = q.transpose(1, 2).contiguous() key = k.transpose(1, 2).contiguous() value = v.transpose(1, 2).contiguous() - out = ( - F.scaled_dot_product_attention( - query, - key, - value - ) - .transpose(1, 2) - .contiguous() - ) + out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous() return out @@ -279,7 +277,7 @@ class Kandinsky5RoPE1D(nn.Module): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) @@ -307,9 +305,15 @@ 
class Kandinsky5RoPE3D(nn.Module): args = torch.cat( [ - args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), - args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), - args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + args_t.view(1, duration, 1, 1, -1).repeat( + batch_size, 1, height, width, 1 + ), + args_h.view(1, 1, height, 1, -1).repeat( + batch_size, duration, 1, width, 1 + ), + args_w.view(1, 1, 1, width, -1).repeat( + batch_size, duration, height, 1, 1 + ), ], dim=-1, ) @@ -318,12 +322,12 @@ class Kandinsky5RoPE3D(nn.Module): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) - setattr(self, f'args_{i}', torch.outer(pos, freq)) + setattr(self, f"args_{i}", torch.outer(pos, freq)) class Kandinsky5Modulation(nn.Module): @@ -341,7 +345,7 @@ class Kandinsky5Modulation(nn.Module): class Kandinsky5SDPAAttentionProcessor(nn.Module): """Custom attention processor for standard SDPA attention""" - + def __call__( self, attn, @@ -357,7 +361,7 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module): class Kandinsky5NablaAttentionProcessor(nn.Module): """Custom attention processor for Nabla attention""" - + def __call__( self, attn, @@ -369,11 +373,11 @@ class Kandinsky5NablaAttentionProcessor(nn.Module): ): if sparse_params is None: raise ValueError("sparse_params is required for Nabla attention") - + query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - + block_mask = nablaT_v2( query, key, @@ -381,12 +385,7 @@ class Kandinsky5NablaAttentionProcessor(nn.Module): thr=sparse_params["P"], ) out = ( - flex_attention( - query, - key, - value, - block_mask=block_mask - ) + flex_attention(query, key, value, block_mask=block_mask) .transpose(1, 2) .contiguous() ) @@ -407,7 +406,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -430,13 +429,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): def scaled_dot_product_attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -466,7 +459,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processors self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() self.nabla_processor = Kandinsky5NablaAttentionProcessor() @@ -490,14 +483,8 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) - + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) + def nabla(self, query, key, value, sparse_params=None): # Use the processor return self.nabla_processor( @@ -506,7 
+493,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module): key=key, value=value, sparse_params=sparse_params, - **{} + **{}, ) def out_l(self, x): @@ -540,7 +527,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -563,13 +550,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -605,7 +586,9 @@ class Kandinsky5OutLayer(nn.Module): ) def forward(self, visual_embed, text_embed, time_embed): - shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + shift, scale = torch.chunk( + self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) visual_embed = apply_scale_shift_norm( self.norm, visual_embed, @@ -646,7 +629,9 @@ class Kandinsky5TransformerEncoderBlock(nn.Module): self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) @@ -678,26 +663,40 @@ class Kandinsky5TransformerDecoderBlock(nn.Module): self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.self_attention_norm, visual_embed, scale, shift + ) visual_out = self.self_attention(visual_out, rope, sparse_params) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.cross_attention_norm, visual_embed, scale, shift + ) visual_out = self.cross_attention(visual_out, text_embed) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.feed_forward_norm, visual_embed, scale, shift + ) visual_out = self.feed_forward(visual_out) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): +class Kandinsky5Transformer3DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): """ A 3D Diffusion Transformer model for video-like data. 
""" + _supports_gradient_checkpointing = True @register_to_config @@ -714,21 +713,21 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr num_text_blocks=2, num_visual_blocks=32, axes_dims=(16, 24, 24), - visual_cond=False, + visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, # Default for Nabla: false - attention_local: bool = None, # Default for Nabla: false - attention_glob: bool = None, # Default for Nabla: false - attention_window: int = None, # Default for Nabla: 3 - attention_P: float = None, # Default for Nabla: 0.9 - attention_wT: int = None, # Default for Nabla: 11 - attention_wW: int = None, # Default for Nabla: 3 - attention_wH: int = None, # Default for Nabla: 3 - attention_add_sta: bool = None, # Default for Nabla: true - attention_method: str = None, # Default for Nabla: "topcdf" + attention_causal: bool = None, # Default for Nabla: false + attention_local: bool = None, # Default for Nabla: false + attention_glob: bool = None, # Default for Nabla: false + attention_window: int = None, # Default for Nabla: 3 + attention_P: float = None, # Default for Nabla: 0.9 + attention_wT: int = None, # Default for Nabla: 11 + attention_wW: int = None, # Default for Nabla: 3 + attention_wH: int = None, # Default for Nabla: 3 + attention_add_sta: bool = None, # Default for Nabla: true + attention_method: str = None, # Default for Nabla: "topcdf" ): super().__init__() - + head_dim = sum(axes_dims) self.in_visual_dim = in_visual_dim self.model_dim = model_dim @@ -737,12 +736,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - + # Initialize embeddings self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + self.visual_embeddings = Kandinsky5VisualEmbeddings( + visual_embed_dim, model_dim, patch_size + ) # Initialize positional embeddings self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) @@ -764,10 +765,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr ) # Initialize output layer - self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.out_layer = Kandinsky5OutLayer( + model_dim, time_dim, out_visual_dim, patch_size + ) self.gradient_checkpointing = False - def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): + def prepare_text_embeddings( + self, text_embed, time, pooled_text_embed, x, text_rope_pos + ): """Prepare text embeddings and related components""" text_embed = self.text_embeddings(text_embed) time_embed = self.time_embeddings(time) @@ -777,38 +782,58 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr text_rope = text_rope.unsqueeze(dim=0) return text_embed, time_embed, text_rope, visual_embed - def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params): + def prepare_visual_embeddings( + self, visual_embed, visual_rope_pos, scale_factor, sparse_params + ): """Prepare visual embeddings and related components""" visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, 
scale_factor) + visual_rope = self.visual_rope_embeddings( + visual_shape, visual_rope_pos, scale_factor + ) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, - block_mask=to_fractal) + visual_embed, visual_rope = fractal_flatten( + visual_embed, visual_rope, visual_shape, block_mask=to_fractal + ) return visual_embed, visual_shape, to_fractal, visual_rope def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) else: text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed - def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): + def process_visual_transformer_blocks( + self, visual_embed, text_embed, time_embed, visual_rope, sparse_params + ): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) else: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) return visual_embed - def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + def prepare_output( + self, visual_embed, visual_shape, to_fractal, text_embed, time_embed + ): """Prepare the final output""" - visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) + visual_embed = fractal_unflatten( + visual_embed, visual_shape, block_mask=to_fractal + ) x = self.out_layer(visual_embed, text_embed, time_embed) return x @@ -846,25 +871,34 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr text_embed = encoder_hidden_states time = timestep pooled_text_embed = pooled_projections - + # Prepare text embeddings and related components text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( - text_embed, time, pooled_text_embed, x, text_rope_pos) + text_embed, time, pooled_text_embed, x, text_rope_pos + ) # Process text through transformer blocks - text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope) + text_embed = self.process_text_transformer_blocks( + text_embed, time_embed, text_rope + ) # Prepare visual embeddings and related components - visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings( - visual_embed, visual_rope_pos, scale_factor, sparse_params) + visual_embed, visual_shape, to_fractal, visual_rope = ( + self.prepare_visual_embeddings( + visual_embed, visual_rope_pos, scale_factor, sparse_params + ) + ) # Process visual through transformer blocks visual_embed = 
self.process_visual_transformer_blocks( - visual_embed, text_embed, time_embed, visual_rope, sparse_params) - + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + # Prepare final output - x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - + x = self.prepare_output( + visual_embed, visual_shape, to_fractal, text_embed, time_embed + ) + if not return_dict: return x
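For reference, a minimal sketch of the call pattern behind Kandinsky5NablaAttentionProcessor: tensors are laid out as (batch, heads, seq, head_dim), a BlockMask is built, and flex_attention consumes it. nablaT_v2 and sparse_params are defined elsewhere in transformer_kandinsky.py and are not part of this patch, so a plain causal mask_mod stands in for the Nabla mask below; the shapes and the causal_mask_mod name are illustrative only.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Stand-in mask; the patched processor instead builds its mask with
    # nablaT_v2(query, key, sparse_params, thr=sparse_params["P"]).
    return q_idx >= kv_idx

batch, heads, seq_len, head_dim = 1, 8, 256, 64
# flex_attention expects (batch, heads, seq_len, head_dim), which is why the
# processor calls transpose(1, 2).contiguous() before and after the kernel.
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# B=None / H=None broadcast the mask over batch and heads; in practice this is
# run on GPU and usually under torch.compile for speed.
block_mask = create_block_mask(
    causal_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=query.device
)
out = flex_attention(query, key, value, block_mask=block_mask).transpose(1, 2).contiguous()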
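The decoder-block hunks above follow an adaLN-style modulation pattern: a projection of the time embedding is chunked into shift/scale/gate, the normalized activations are rescaled, and the branch output is gated back into the residual stream. apply_scale_shift_norm, apply_gate_sum, and Kandinsky5Modulation are not included in this patch, so the sketch below assumes the common formulation (norm(x) * (1 + scale) + shift, then x + gate * out) and a SiLU + Linear modulation; treat those definitions as placeholders, not the file's actual helpers.

import torch
import torch.nn as nn

def apply_scale_shift_norm(norm, x, scale, shift):
    # Assumed adaLN-style formulation; the real helper lives outside this diff.
    return norm(x) * (1.0 + scale) + shift

def apply_gate_sum(x, out, gate):
    # Assumed gated residual update; the real helper lives outside this diff.
    return x + gate * out

model_dim, time_dim, batch, tokens = 64, 128, 2, 16
modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, 3 * model_dim))  # placeholder for Kandinsky5Modulation
norm = nn.LayerNorm(model_dim, elementwise_affine=False)                   # stands in for FP32LayerNorm

x = torch.randn(batch, tokens, model_dim)
time_embed = torch.randn(batch, time_dim)

# As in the patch: project the time embedding, broadcast over the token axis,
# then split into shift / scale / gate.
shift, scale, gate = torch.chunk(modulation(time_embed).unsqueeze(dim=1), 3, dim=-1)
out = apply_scale_shift_norm(norm, x, scale, shift)  # followed by attention or feed-forward in the block
x = apply_gate_sum(x, out, gate)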
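And a condensed sketch of the Kandinsky5RoPE3D broadcasting shown above: per-axis angle tables (outer(pos, freq), as in reset_dtype) are expanded over the (duration, height, width) grid, concatenated, and turned into 2x2 rotation blocks. get_freqs is not part of this patch, so the standard RoPE inverse-frequency schedule is assumed for it; the view/repeat/cat lines mirror the forward hunk verbatim.

import torch

def get_freqs(dim, max_period=10_000.0):
    # Assumed: standard RoPE frequencies 1 / max_period**(i / dim).
    return 1.0 / max_period ** (torch.arange(0, dim, dtype=torch.float32) / dim)

axes_dims = (16, 24, 24)  # per-axis rotary dims, matching the model config default
batch_size, duration, height, width = 1, 4, 6, 6

# Precompute outer(pos, freq) tables per axis, mirroring reset_dtype in the patch.
args_t, args_h, args_w = (
    torch.outer(torch.arange(n, dtype=torch.float32), get_freqs(d // 2))
    for n, d in zip((duration, height, width), axes_dims)
)

# Broadcast each axis over the full (duration, height, width) grid and concatenate,
# exactly as in the Kandinsky5RoPE3D.forward hunk above.
angles = torch.cat(
    [
        args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
        args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
        args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
    ],
    dim=-1,
)
cosine, sine = torch.cos(angles), torch.sin(angles)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1).view(*angles.shape, 2, 2)
print(rope.shape)  # (batch, duration, height, width, sum(axes_dims) // 2, 2, 2)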