From c8be08149e80ae22e7a7d3b4a1f2413a9f149690 Mon Sep 17 00:00:00 2001
From: leffff
Date: Thu, 16 Oct 2025 09:31:12 +0000
Subject: [PATCH] Transformer: move all methods to forward

---
 .../transformers/transformer_kandinsky.py | 189 +++++-------------
 1 file changed, 49 insertions(+), 140 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 8d3b4fac51..45e4238cfb 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -616,47 +616,6 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
         return x
 
 
-# class Kandinsky5TransformerDecoderBlock(nn.Module):
-#     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
-#         super().__init__()
-#         self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
-
-#         self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
-
-#         self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
-
-#         self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
-
-#     def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
-#         self_attn_params, cross_attn_params, ff_params = torch.chunk(
-#             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
-#         )
-#         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.self_attention_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.self_attention(visual_out, rope, sparse_params)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-
-#         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.cross_attention_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.cross_attention(visual_out, text_embed)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-
-#         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.feed_forward_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.feed_forward(visual_out)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-#         return visual_embed
-
-
 class Kandinsky5TransformerDecoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
         super().__init__()
@@ -724,16 +683,16 @@ class Kandinsky5Transformer3DModel(
         axes_dims=(16, 24, 24),
         visual_cond=False,
         attention_type: str = "regular",
-        attention_causal: bool = None,  # Default for Nabla: false
-        attention_local: bool = None,  # Default for Nabla: false
-        attention_glob: bool = None,  # Default for Nabla: false
-        attention_window: int = None,  # Default for Nabla: 3
-        attention_P: float = None,  # Default for Nabla: 0.9
-        attention_wT: int = None,  # Default for Nabla: 11
-        attention_wW: int = None,  # Default for Nabla: 3
-        attention_wH: int = None,  # Default for Nabla: 3
-        attention_add_sta: bool = None,  # Default for Nabla: true
-        attention_method: str = None,  # Default for Nabla: "topcdf"
+        attention_causal: bool = None,
+        attention_local: bool = None,
+        attention_glob: bool = None,
+        attention_window: int = None,
+        attention_P: float = None,
+        attention_wT: int = None,
+        attention_wW: int = None,
+        attention_wH: int = None,
+        attention_add_sta: bool = None,
+        attention_method: str = None,
     ):
         super().__init__()
 
@@ -779,73 +738,6 @@ class Kandinsky5Transformer3DModel(
         )
         self.gradient_checkpointing = False
 
-    def prepare_text_embeddings(
-        self, text_embed, time, pooled_text_embed, x, text_rope_pos
-    ):
-        """Prepare text embeddings and related components"""
-        text_embed = self.text_embeddings(text_embed)
-        time_embed = self.time_embeddings(time)
-        time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
-        visual_embed = self.visual_embeddings(x)
-        text_rope = self.text_rope_embeddings(text_rope_pos)
-        text_rope = text_rope.unsqueeze(dim=0)
-        return text_embed, time_embed, text_rope, visual_embed
-
-    def prepare_visual_embeddings(
-        self, visual_embed, visual_rope_pos, scale_factor, sparse_params
-    ):
-        """Prepare visual embeddings and related components"""
-        visual_shape = visual_embed.shape[:-1]
-        visual_rope = self.visual_rope_embeddings(
-            visual_shape, visual_rope_pos, scale_factor
-        )
-        to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
-        visual_embed, visual_rope = fractal_flatten(
-            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
-        )
-        return visual_embed, visual_shape, to_fractal, visual_rope
-
-    def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
-        """Process text through transformer blocks"""
-        for text_transformer_block in self.text_transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                text_embed = self._gradient_checkpointing_func(
-                    text_transformer_block, text_embed, time_embed, text_rope
-                )
-            else:
-                text_embed = text_transformer_block(text_embed, time_embed, text_rope)
-        return text_embed
-
-    def process_visual_transformer_blocks(
-        self, visual_embed, text_embed, time_embed, visual_rope, sparse_params
-    ):
-        """Process visual through transformer blocks"""
-        for visual_transformer_block in self.visual_transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                visual_embed = self._gradient_checkpointing_func(
-                    visual_transformer_block,
-                    visual_embed,
-                    text_embed,
-                    time_embed,
-                    visual_rope,
-                    sparse_params,
-                )
-            else:
-                visual_embed = visual_transformer_block(
-                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
-                )
-        return visual_embed
-
-    def prepare_output(
-        self, visual_embed, visual_shape, to_fractal, text_embed, time_embed
-    ):
-        """Prepare the final output"""
-        visual_embed = fractal_unflatten(
-            visual_embed, visual_shape, block_mask=to_fractal
-        )
-        x = self.out_layer(visual_embed, text_embed, time_embed)
-        return x
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,  # x
@@ -881,32 +773,49 @@ class Kandinsky5Transformer3DModel(
         time = timestep
         pooled_text_embed = pooled_projections
 
-        # Prepare text embeddings and related components
-        text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings(
-            text_embed, time, pooled_text_embed, x, text_rope_pos
+        text_embed = self.text_embeddings(text_embed)
+        time_embed = self.time_embeddings(time)
+        time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
+        visual_embed = self.visual_embeddings(x)
+        text_rope = self.text_rope_embeddings(text_rope_pos)
+        text_rope = text_rope.unsqueeze(dim=0)
+
+        for text_transformer_block in self.text_transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                text_embed = self._gradient_checkpointing_func(
+                    text_transformer_block, text_embed, time_embed, text_rope
+                )
+            else:
+                text_embed = text_transformer_block(text_embed, time_embed, text_rope)
+
+        visual_shape = visual_embed.shape[:-1]
+        visual_rope = self.visual_rope_embeddings(
+            visual_shape, visual_rope_pos, scale_factor
+        )
+        to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
+        visual_embed, visual_rope = fractal_flatten(
+            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
         )
 
-        # Process text through transformer blocks
-        text_embed = self.process_text_transformer_blocks(
-            text_embed, time_embed, text_rope
-        )
+        for visual_transformer_block in self.visual_transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                visual_embed = self._gradient_checkpointing_func(
+                    visual_transformer_block,
+                    visual_embed,
+                    text_embed,
+                    time_embed,
+                    visual_rope,
+                    sparse_params,
+                )
+            else:
+                visual_embed = visual_transformer_block(
+                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
+                )
 
-        # Prepare visual embeddings and related components
-        visual_embed, visual_shape, to_fractal, visual_rope = (
-            self.prepare_visual_embeddings(
-                visual_embed, visual_rope_pos, scale_factor, sparse_params
-            )
-        )
-
-        # Process visual through transformer blocks
-        visual_embed = self.process_visual_transformer_blocks(
-            visual_embed, text_embed, time_embed, visual_rope, sparse_params
-        )
-
-        # Prepare final output
-        x = self.prepare_output(
-            visual_embed, visual_shape, to_fractal, text_embed, time_embed
+        visual_embed = fractal_unflatten(
+            visual_embed, visual_shape, block_mask=to_fractal
        )
+        x = self.out_layer(visual_embed, text_embed, time_embed)
 
         if not return_dict:
             return x
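
Note on the pattern the inlined forward() relies on: each block loop dispatches through gradient checkpointing when it is enabled and gradients are being tracked, and calls the block directly otherwise. The sketch below is a minimal, self-contained illustration of that dispatch under stated assumptions, not the diffusers implementation: Block and TinyTransformer are toy stand-ins for the Kandinsky5 blocks, and torch.utils.checkpoint.checkpoint with use_reentrant=False stands in for the _gradient_checkpointing_func helper that the patched model calls.

# Minimal sketch (assumed toy module names, not the Kandinsky5 classes) of the
# block-loop pattern used in the patched forward(): run each block under
# activation checkpointing when enabled and gradients are tracked, otherwise
# call the block directly.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # A tiny residual feed-forward block standing in for a transformer block.
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())

    def forward(self, x):
        return x + self.ff(x)


class TinyTransformer(nn.Module):
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this block's activations in backward instead of storing them.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    model = TinyTransformer()
    model.gradient_checkpointing = True
    out = model(torch.randn(2, 16, 64))
    out.sum().backward()

With checkpointing enabled, intermediate activations inside each block are recomputed during the backward pass, trading compute for memory; gating on torch.is_grad_enabled() leaves inference untouched.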