From 31a1474378a0ae3fe22bc626f7fe274c99ed30fd Mon Sep 17 00:00:00 2001
From: leffff
Date: Thu, 16 Oct 2025 08:46:34 +0000
Subject: [PATCH] remove one-line functions

---
 .../transformers/transformer_kandinsky.py | 91 ++++++++++++-------
 1 file changed, 59 insertions(+), 32 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index febe6cff7a..bed1938ae3 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -174,16 +174,6 @@ def nablaT_v2(
     )
 
 
-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_scale_shift_norm(norm, x, scale, shift):
-    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
-
-
-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_gate_sum(x, out, gate):
-    return (x + gate * out).to(torch.bfloat16)
-
-
 @torch.autocast(device_type="cuda", enabled=False)
 def apply_rotary(x, rope):
     x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
@@ -327,6 +317,8 @@ class Kandinsky5Modulation(nn.Module):
         super().__init__()
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, num_params * model_dim)
+        self.out_layer.weight.data.zero_()
+        self.out_layer.bias.data.zero_()
 
     @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
@@ -585,12 +577,9 @@ class Kandinsky5OutLayer(nn.Module):
         shift, scale = torch.chunk(
             self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
-        visual_embed = apply_scale_shift_norm(
-            self.norm,
-            visual_embed,
-            scale[:, None, None],
-            shift[:, None, None],
-        ).type_as(visual_embed)
+
+        visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]).type_as(visual_embed)
+
         x = self.out_layer(visual_embed)
 
         batch_size, duration, height, width, _ = x.shape
@@ -629,17 +618,59 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
             self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+        out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.self_attention(out, rope)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+        out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.feed_forward(out)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)
+
         return x
 
 
+# class Kandinsky5TransformerDecoderBlock(nn.Module):
+#     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+#         super().__init__()
+#         self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
+
+#         self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
+
+#         self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
+
+#         self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+#     def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
+#         self_attn_params, cross_attn_params, ff_params = torch.chunk(
+#             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
+#         )
+#         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.self_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.self_attention(visual_out, rope, sparse_params)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+
+#         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.cross_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.cross_attention(visual_out, text_embed)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+
+#         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.feed_forward_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.feed_forward(visual_out)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+#         return visual_embed
+
+
 class Kandinsky5TransformerDecoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
         super().__init__()
@@ -658,26 +689,22 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(
             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
         )
+
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.self_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.self_attention(visual_out, rope, sparse_params)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
 
         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.cross_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.cross_attention(visual_out, text_embed)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.feed_forward_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.feed_forward(visual_out)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
         return visual_embed
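
Note on the inlined expressions: the removed apply_scale_shift_norm and apply_gate_sum helpers ran under @torch.autocast(device_type="cuda", dtype=torch.float32) and hard-coded a .to(torch.bfloat16) downcast. The inlined versions upcast explicitly with .float() and restore the input dtype with .type_as(...), so the float32 math no longer depends on CUDA autocast and no longer assumes bf16 activations. A minimal sketch of the pattern; the helper names below are illustrative, not part of the file:

    import torch
    from torch import nn

    def scale_shift_norm(norm, x, scale, shift):
        # Normalize and modulate in float32, then cast back to x's own
        # dtype instead of hard-coding torch.bfloat16.
        return (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)

    def gate_sum(x, out, gate):
        # Gated residual add, accumulated in float32.
        return (x.float() + gate.float() * out.float()).type_as(x)

    norm = nn.LayerNorm(64, elementwise_affine=False)
    x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
    shift, scale, gate = torch.chunk(torch.randn(2, 1, 3 * 64), 3, dim=-1)

    out = scale_shift_norm(norm, x, scale, shift)
    out = gate_sum(x, out, gate)
    assert out.dtype == torch.bfloat16  # dtype round-trips with the input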
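
Note on the new zero-initialization in Kandinsky5Modulation: with out_layer zeroed, every predicted shift, scale, and gate starts at exactly zero, so each modulated block initially reduces to plain norm(x) on the norm path and to a pass-through on the gated residual path (x + 0 * out). This mirrors the adaLN-Zero initialization used in DiT-style transformers. A short sketch under those assumptions; Modulation here is a stand-in, not the diffusers class:

    import torch
    from torch import nn

    class Modulation(nn.Module):
        # Stand-in for Kandinsky5Modulation: a SiLU + Linear head that predicts
        # num_params (e.g. shift/scale/gate) vectors from the time embedding.
        def __init__(self, time_dim, model_dim, num_params):
            super().__init__()
            self.activation = nn.SiLU()
            self.out_layer = nn.Linear(time_dim, num_params * model_dim)
            self.out_layer.weight.data.zero_()  # zero-init, as in the patch
            self.out_layer.bias.data.zero_()

        def forward(self, x):
            return self.out_layer(self.activation(x))

    mod = Modulation(time_dim=256, model_dim=128, num_params=3)
    params = mod(torch.randn(2, 256))
    assert params.abs().max().item() == 0.0  # shift/scale/gate all zero at init
    shift, scale, gate = torch.chunk(params, 3, dim=-1)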