mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Transformer: move all methods to forward

leffff
2025-10-16 09:31:12 +00:00
parent 894aa98a27
commit c8be08149e

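The change is a pure refactor of Kandinsky5Transformer3DModel: the single-use helper methods (prepare_text_embeddings, prepare_visual_embeddings, process_text_transformer_blocks, process_visual_transformer_blocks, prepare_output) are deleted and their bodies are inlined into forward(), and a long commented-out decoder block is removed. A minimal sketch of that pattern, using toy modules rather than the actual diffusers code:

import torch
import torch.nn as nn


class Before(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def prepare(self, x):
        # single-use helper, only ever called from forward()
        return self.proj(x)

    def forward(self, x):
        return self.prepare(x)


class After(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        # helper body moved directly into forward(); behavior is unchanged
        return self.proj(x)


before, after = Before(), After()
after.load_state_dict(before.state_dict())  # copy weights so outputs match
x = torch.randn(2, 8)
assert torch.allclose(before(x), after(x))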

@@ -616,47 +616,6 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
return x
# class Kandinsky5TransformerDecoderBlock(nn.Module):
# def __init__(self, model_dim, time_dim, ff_dim, head_dim):
# super().__init__()
# self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
# self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
# self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
# self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
# self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
# self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
# self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
# def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
# self_attn_params, cross_attn_params, ff_params = torch.chunk(
# self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
# )
# shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
# visual_out = apply_scale_shift_norm(
# self.self_attention_norm, visual_embed, scale, shift
# )
# visual_out = self.self_attention(visual_out, rope, sparse_params)
# visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
# visual_out = apply_scale_shift_norm(
# self.cross_attention_norm, visual_embed, scale, shift
# )
# visual_out = self.cross_attention(visual_out, text_embed)
# visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
# visual_out = apply_scale_shift_norm(
# self.feed_forward_norm, visual_embed, scale, shift
# )
# visual_out = self.feed_forward(visual_out)
# visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# return visual_embed
class Kandinsky5TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
@@ -724,16 +683,16 @@ class Kandinsky5Transformer3DModel(
axes_dims=(16, 24, 24),
visual_cond=False,
attention_type: str = "regular",
attention_causal: bool = None, # Default for Nabla: false
attention_local: bool = None, # Default for Nabla: false
attention_glob: bool = None, # Default for Nabla: false
attention_window: int = None, # Default for Nabla: 3
attention_P: float = None, # Default for Nabla: 0.9
attention_wT: int = None, # Default for Nabla: 11
attention_wW: int = None, # Default for Nabla: 3
attention_wH: int = None, # Default for Nabla: 3
attention_add_sta: bool = None, # Default for Nabla: true
attention_method: str = None, # Default for Nabla: "topcdf"
attention_causal: bool = None,
attention_local: bool = None,
attention_glob: bool = None,
attention_window: int = None,
attention_P: float = None,
attention_wT: int = None,
attention_wW: int = None,
attention_wH: int = None,
attention_add_sta: bool = None,
attention_method: str = None,
):
super().__init__()
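The constructor change above only drops the inline "Default for Nabla" comments; for reference, those comments documented causal=False, local=False, glob=False, window=3, P=0.9, wT=11, wW=3, wH=3, add_sta=True, method="topcdf". A hedged sketch of passing these values explicitly, assuming the class is importable from diffusers and that the remaining constructor arguments have usable defaults:

# assumption: import path mirrors other diffusers transformer models
from diffusers import Kandinsky5Transformer3DModel

model = Kandinsky5Transformer3DModel(
    attention_type="nabla",   # constructor default is "regular"
    attention_causal=False,   # values below reproduce the removed
    attention_local=False,    # "Default for Nabla" comments
    attention_glob=False,
    attention_window=3,
    attention_P=0.9,
    attention_wT=11,
    attention_wW=3,
    attention_wH=3,
    attention_add_sta=True,
    attention_method="topcdf",
)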
@@ -779,73 +738,6 @@ class Kandinsky5Transformer3DModel(
)
self.gradient_checkpointing = False
def prepare_text_embeddings(
self, text_embed, time, pooled_text_embed, x, text_rope_pos
):
"""Prepare text embeddings and related components"""
text_embed = self.text_embeddings(text_embed)
time_embed = self.time_embeddings(time)
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
visual_embed = self.visual_embeddings(x)
text_rope = self.text_rope_embeddings(text_rope_pos)
text_rope = text_rope.unsqueeze(dim=0)
return text_embed, time_embed, text_rope, visual_embed
def prepare_visual_embeddings(
self, visual_embed, visual_rope_pos, scale_factor, sparse_params
):
"""Prepare visual embeddings and related components"""
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(
visual_shape, visual_rope_pos, scale_factor
)
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
visual_embed, visual_rope = fractal_flatten(
visual_embed, visual_rope, visual_shape, block_mask=to_fractal
)
return visual_embed, visual_shape, to_fractal, visual_rope
def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
"""Process text through transformer blocks"""
for text_transformer_block in self.text_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
text_embed = self._gradient_checkpointing_func(
text_transformer_block, text_embed, time_embed, text_rope
)
else:
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
return text_embed
def process_visual_transformer_blocks(
self, visual_embed, text_embed, time_embed, visual_rope, sparse_params
):
"""Process visual through transformer blocks"""
for visual_transformer_block in self.visual_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
visual_embed = self._gradient_checkpointing_func(
visual_transformer_block,
visual_embed,
text_embed,
time_embed,
visual_rope,
sparse_params,
)
else:
visual_embed = visual_transformer_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
return visual_embed
def prepare_output(
self, visual_embed, visual_shape, to_fractal, text_embed, time_embed
):
"""Prepare the final output"""
visual_embed = fractal_unflatten(
visual_embed, visual_shape, block_mask=to_fractal
)
x = self.out_layer(visual_embed, text_embed, time_embed)
return x
def forward(
self,
hidden_states: torch.FloatTensor, # x
@@ -881,32 +773,49 @@ class Kandinsky5Transformer3DModel(
time = timestep
pooled_text_embed = pooled_projections
# Prepare text embeddings and related components
text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings(
text_embed, time, pooled_text_embed, x, text_rope_pos
text_embed = self.text_embeddings(text_embed)
time_embed = self.time_embeddings(time)
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
visual_embed = self.visual_embeddings(x)
text_rope = self.text_rope_embeddings(text_rope_pos)
text_rope = text_rope.unsqueeze(dim=0)
for text_transformer_block in self.text_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
text_embed = self._gradient_checkpointing_func(
text_transformer_block, text_embed, time_embed, text_rope
)
else:
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(
visual_shape, visual_rope_pos, scale_factor
)
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
visual_embed, visual_rope = fractal_flatten(
visual_embed, visual_rope, visual_shape, block_mask=to_fractal
)
# Process text through transformer blocks
text_embed = self.process_text_transformer_blocks(
text_embed, time_embed, text_rope
)
for visual_transformer_block in self.visual_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
visual_embed = self._gradient_checkpointing_func(
visual_transformer_block,
visual_embed,
text_embed,
time_embed,
visual_rope,
sparse_params,
)
else:
visual_embed = visual_transformer_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
# Prepare visual embeddings and related components
visual_embed, visual_shape, to_fractal, visual_rope = (
self.prepare_visual_embeddings(
visual_embed, visual_rope_pos, scale_factor, sparse_params
)
)
# Process visual through transformer blocks
visual_embed = self.process_visual_transformer_blocks(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
# Prepare final output
x = self.prepare_output(
visual_embed, visual_shape, to_fractal, text_embed, time_embed
visual_embed = fractal_unflatten(
visual_embed, visual_shape, block_mask=to_fractal
)
x = self.out_layer(visual_embed, text_embed, time_embed)
if not return_dict:
return x
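With the helpers gone, forward() now contains the gradient-checkpointing dispatch for both the text and visual block stacks directly. A minimal standalone sketch of that dispatch pattern; run_blocks and its arguments are illustrative, not diffusers API:

import torch
from torch.utils.checkpoint import checkpoint


def run_blocks(blocks, hidden, *extra, gradient_checkpointing=False):
    # mirrors the loops in forward(): checkpoint each block only when
    # gradients are enabled and checkpointing was requested
    for block in blocks:
        if torch.is_grad_enabled() and gradient_checkpointing:
            hidden = checkpoint(block, hidden, *extra, use_reentrant=False)
        else:
            hidden = block(hidden, *extra)
    return hidden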