Commit to huggingface/diffusers (https://github.com/huggingface/diffusers.git): remove oneline function
@@ -174,16 +174,6 @@ def nablaT_v2(
     )
 
 
-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_scale_shift_norm(norm, x, scale, shift):
-    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
-
-
-@torch.autocast(device_type="cuda", dtype=torch.float32)
-def apply_gate_sum(x, out, gate):
-    return (x + gate * out).to(torch.bfloat16)
-
-
 @torch.autocast(device_type="cuda", enabled=False)
 def apply_rotary(x, rope):
     x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
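Note: the two deleted helpers wrapped the model's AdaLN-style modulation and gated-residual one-liners in a CUDA-only float32 autocast and hard-coded a bfloat16 result. The later hunks inline the same arithmetic with explicit casts. A minimal sketch of the equivalence, with assumed shapes and the autocast emulated by explicit .float() casts so it runs on any device:

import torch
import torch.nn as nn

def apply_scale_shift_norm(norm, x, scale, shift):
    # the removed helper; its fp32 autocast is emulated with explicit casts
    return (norm(x.float()) * (scale.float() + 1.0) + shift.float()).to(torch.bfloat16)

norm = nn.LayerNorm(64, elementwise_affine=False)
x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
scale = torch.randn(2, 1, 64, dtype=torch.bfloat16)
shift = torch.randn(2, 1, 64, dtype=torch.bfloat16)

# the inlined replacement computes the same values with explicit fp32 casts
helper = apply_scale_shift_norm(norm, x, scale, shift)
inline = (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
assert torch.equal(helper, inline)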
@@ -327,6 +317,8 @@ class Kandinsky5Modulation(nn.Module):
         super().__init__()
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, num_params * model_dim)
+        self.out_layer.weight.data.zero_()
+        self.out_layer.bias.data.zero_()
 
     @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
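The two added lines zero-initialize the modulation projection: with zero weight and bias, every (shift, scale, gate) chunk starts at zero, so each modulated block begins as an identity mapping (the adaLN-Zero initialization from DiT). A sketch with hypothetical dims:

import torch
import torch.nn as nn

# shift = scale = gate = 0 at init, so norm(x) * (scale + 1) + shift == norm(x)
# and x + gate * out == x: the block starts as a no-op on the residual stream.
model_dim, time_dim = 64, 128
out_layer = nn.Linear(time_dim, 3 * model_dim)
out_layer.weight.data.zero_()
out_layer.bias.data.zero_()

shift, scale, gate = torch.chunk(out_layer(torch.randn(2, time_dim)), 3, dim=-1)
assert shift.abs().sum() == 0 and scale.abs().sum() == 0 and gate.abs().sum() == 0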
@@ -585,12 +577,9 @@ class Kandinsky5OutLayer(nn.Module):
         shift, scale = torch.chunk(
             self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
-        visual_embed = apply_scale_shift_norm(
-            self.norm,
-            visual_embed,
-            scale[:, None, None],
-            shift[:, None, None],
-        ).type_as(visual_embed)
+        visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]).type_as(visual_embed)
 
         x = self.out_layer(visual_embed)
 
         batch_size, duration, height, width, _ = x.shape
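Here the modulation params are per batch and channel, while visual_embed is a 5-D video tensor; the [:, None, None] indexing inserts singleton axes so the params broadcast over the spatio-temporal grid. A shape sketch with assumed sizes:

import torch

# visual_embed: (batch, duration, height, width, channels); after
# unsqueeze(dim=1) scale is (batch, 1, channels), and scale[:, None, None]
# becomes (batch, 1, 1, 1, channels), broadcasting over duration/height/width.
visual_embed = torch.randn(2, 4, 8, 8, 64)
scale = torch.randn(2, 64).unsqueeze(dim=1)

assert scale[:, None, None].shape == (2, 1, 1, 1, 64)
assert (visual_embed * (scale[:, None, None] + 1.0)).shape == visual_embed.shape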
@@ -629,17 +618,59 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
             self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+        out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.self_attention(out, rope)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+        out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
         out = self.feed_forward(out)
-        x = apply_gate_sum(x, out, gate)
+        x = (x.float() + gate.float() * out.float()).type_as(x)
 
         return x
 
 
+# class Kandinsky5TransformerDecoderBlock(nn.Module):
+#     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+#         super().__init__()
+#         self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
+
+#         self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
+
+#         self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
+
+#         self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+#         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+#     def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
+#         self_attn_params, cross_attn_params, ff_params = torch.chunk(
+#             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
+#         )
+#         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.self_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.self_attention(visual_out, rope, sparse_params)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+
+#         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.cross_attention_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.cross_attention(visual_out, text_embed)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+
+#         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+#         visual_out = apply_scale_shift_norm(
+#             self.feed_forward_norm, visual_embed, scale, shift
+#         )
+#         visual_out = self.feed_forward(visual_out)
+#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+#         return visual_embed
+
+
 class Kandinsky5TransformerDecoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
         super().__init__()
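This hunk both inlines the encoder block's helper calls and keeps a commented-out copy of the decoder block above the live class; note the commented copy still calls the removed helpers, so it is reference-only and would not run if uncommented. A runnable sketch of the inlined encoder flow, with assumed dims, a plain Linear standing in for Kandinsky5Modulation, and the attention call elided:

import torch
import torch.nn as nn

# text_modulation emits 6 * model_dim params: (shift, scale, gate) for the
# self-attention sub-block and again for the feed-forward sub-block.
model_dim, time_dim = 64, 128
text_modulation = nn.Linear(time_dim, 6 * model_dim)
norm = nn.LayerNorm(model_dim, elementwise_affine=False)

x = torch.randn(2, 16, model_dim, dtype=torch.bfloat16)
time_embed = torch.randn(2, time_dim)

self_attn_params, ff_params = torch.chunk(
    text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
# ... self.self_attention(out, rope) would run here ...
x = (x.float() + gate.float() * out.float()).type_as(x)
assert x.dtype == torch.bfloat16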
@@ -658,26 +689,22 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(
             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
         )
 
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.self_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.self_attention(visual_out, rope, sparse_params)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
 
         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.cross_attention_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.cross_attention(visual_out, text_embed)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(
-            self.feed_forward_norm, visual_embed, scale, shift
-        )
+        visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
         visual_out = self.feed_forward(visual_out)
-        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+        visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
 
         return visual_embed
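A final note on why the inlined form is arguably safer than the deleted helpers: the helpers hard-coded .to(torch.bfloat16) and a device_type="cuda" autocast, while the inlined expressions upcast to float32 for the arithmetic and return through .type_as(...), preserving whatever dtype and device the model actually runs in. Sketch with assumed shapes:

import torch

x = torch.randn(2, 16, 64, dtype=torch.float16)  # e.g. fp16 inference
gate = torch.randn(2, 1, 64, dtype=torch.float16)
out = torch.randn(2, 16, 64, dtype=torch.float16)

old_style = (x + gate * out).to(torch.bfloat16)  # helper forced bf16 output
new_style = (x.float() + gate.float() * out.float()).type_as(x)
assert old_style.dtype == torch.bfloat16
assert new_style.dtype == torch.float16  # follows the input dtype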