
remove oneline function

leffff
2025-10-16 08:46:34 +00:00
parent 600e9d6b87
commit 31a1474378
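
The change deletes the one-line module-level helpers apply_scale_shift_norm and apply_gate_sum, which relied on a CUDA-only torch.autocast decorator, and writes the equivalent arithmetic inline with explicit .float() upcasts and .type_as() downcasts. A minimal sketch of the before/after pattern (the helper is the one deleted below; the inline function name is mine):

import torch

# Before: one-line helper forced to fp32 on CUDA by the autocast decorator,
# always casting its result to bfloat16.
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)

# After: the same arithmetic inlined at each call site, with the upcast made
# explicit and the original dtype of x restored instead of hard-coding bfloat16.
def scale_shift_norm_inline(norm, x, scale, shift):
    return (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)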


@@ -174,16 +174,6 @@ def nablaT_v2(
     )


-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_scale_shift_norm(norm, x, scale, shift):
-    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
-
-
-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_gate_sum(x, out, gate):
-    return (x + gate * out).to(torch.bfloat16)
-
-
 @torch.autocast(device_type="cuda", enabled=False)
 def apply_rotary(x, rope):
     x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
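
Since @torch.autocast(device_type="cuda", ...) only affects CUDA execution, the inline spelling is also the device-agnostic one. A quick check (my own, not from the repo) that the inlined arithmetic behaves as expected on CPU with bfloat16 inputs:

import torch
import torch.nn as nn

norm = nn.LayerNorm(64, elementwise_affine=False)
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
scale = torch.randn(2, 1, 64, dtype=torch.bfloat16)
shift = torch.randn(2, 1, 64, dtype=torch.bfloat16)

# Inlined form: upcast to fp32, normalize and modulate, cast back to bf16.
out = (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
ref = norm(x.float()) * (scale.float() + 1.0) + shift.float()  # pure fp32 reference

assert out.dtype == torch.bfloat16
assert torch.allclose(out.float(), ref, atol=1e-1)  # bf16 rounding tolerance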
@@ -327,6 +317,8 @@ class Kandinsky5Modulation(nn.Module):
         super().__init__()
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, num_params * model_dim)
+        self.out_layer.weight.data.zero_()
+        self.out_layer.bias.data.zero_()

     @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
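
The two added lines zero out the modulation projection, so Kandinsky5Modulation emits all-zero parameters at initialization. Combined with the scale + 1.0 and gate * out forms used throughout the blocks below, each block then starts out as an identity mapping (a common zero-init trick for adaptive-norm transformers). A small sketch of why, with made-up dimensions:

import torch
import torch.nn as nn

time_dim, model_dim = 32, 64
out_layer = nn.Linear(time_dim, 3 * model_dim)  # 3 params: shift, scale, gate
out_layer.weight.data.zero_()
out_layer.bias.data.zero_()

time_embed = torch.randn(2, time_dim)
shift, scale, gate = torch.chunk(out_layer(time_embed).unsqueeze(dim=1), 3, dim=-1)

x = torch.randn(2, 16, model_dim)
norm = nn.LayerNorm(model_dim, elementwise_affine=False)
out = norm(x) * (scale + 1.0) + shift  # scale = shift = 0  ->  out == norm(x)
assert torch.equal(x + gate * out, x)  # gate = 0  ->  the residual branch is a no-op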
@@ -585,12 +577,9 @@ class Kandinsky5OutLayer(nn.Module):
         shift, scale = torch.chunk(
             self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
-        visual_embed = apply_scale_shift_norm(
-            self.norm,
-            visual_embed,
-            scale[:, None, None],
-            shift[:, None, None],
-        ).type_as(visual_embed)
+        visual_embed = (
+            self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]
+        ).type_as(visual_embed)

         x = self.out_layer(visual_embed)
         batch_size, duration, height, width, _ = x.shape
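
In Kandinsky5OutLayer the modulation tensors come out of the unsqueeze/chunk with shape (batch, 1, channels), while visual_embed is 5-D, as the batch_size, duration, height, width unpacking above shows; the [:, None, None] indexing inserts the two missing spatial axes so the modulation broadcasts over time and space. A shape-only sketch with made-up sizes:

import torch

batch, duration, height, width, channels = 2, 4, 8, 8, 64
visual_embed = torch.randn(batch, duration, height, width, channels)
scale = torch.randn(batch, 1, channels)  # shape after unsqueeze(dim=1) + chunk

scale_b = scale[:, None, None]           # -> (batch, 1, 1, 1, channels)
assert scale_b.shape == (batch, 1, 1, 1, channels)

out = visual_embed * (scale_b + 1.0)     # broadcasts across duration/height/width
assert out.shape == visual_embed.shape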
@@ -629,17 +618,59 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
             self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+        out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.self_attention(out, rope)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)

         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+        out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.feed_forward(out)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)
         return x


+# class Kandinsky5TransformerDecoderBlock(nn.Module):
+#     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+#         super().__init__()
+#         self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
+#         self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
+#         self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
+#         self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+#
+#     def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
+#         self_attn_params, cross_attn_params, ff_params = torch.chunk(
+#             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
+#         )
+#         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.self_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.self_attention(visual_out, rope, sparse_params)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+#
+#         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.cross_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.cross_attention(visual_out, text_embed)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+#
+#         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.feed_forward_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.feed_forward(visual_out)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+#
+#         return visual_embed
+
+
 class Kandinsky5TransformerDecoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
         super().__init__()
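
Read together, each inlined stanza in these blocks does the norm/modulate arithmetic and the gated residual sum in fp32, while the attention or feed-forward sub-layer in between still runs in the activation dtype. A condensed, self-contained rendering of one stanza (the sub-layer is stubbed; only the dtype handling mirrors the diff):

import torch
import torch.nn as nn

def stanza(x, norm, scale, shift, gate, sublayer):
    # fp32 for normalization + modulation, back to x.dtype for the sub-layer
    out = (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
    out = sublayer(out)  # e.g. self-attention, cross-attention, or feed-forward
    # fp32 again for the gated residual accumulation
    return (x.float() + gate.float() * out.float()).type_as(x)

norm = nn.LayerNorm(64, elementwise_affine=False)
x = torch.randn(2, 10, 64, dtype=torch.bfloat16)
zero = torch.zeros(2, 1, 64, dtype=torch.bfloat16)  # zero-init modulation params
y = stanza(x, norm, zero, zero, zero, lambda t: t)
assert torch.equal(y, x)  # with a zeroed gate the stanza passes x through exactly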
@@ -658,26 +689,22 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(
             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.self_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.self_attention(visual_out, rope, sparse_params)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)

         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.cross_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.cross_attention(visual_out, text_embed)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)

         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.feed_forward_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.feed_forward(visual_out)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)

         return visual_embed